diff --git a/scripts/_mkl/notebooks/00a - Types.ipynb b/scripts/_mkl/notebooks/00a - Types.ipynb index 1edbdd39..0a043c74 100644 --- a/scripts/_mkl/notebooks/00a - Types.ipynb +++ b/scripts/_mkl/notebooks/00a - Types.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp types" + "# |default_exp types" ] }, { @@ -15,7 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from typing import Any, NamedTuple\n", "import numpy as np\n", "import jax\n", @@ -29,18 +29,18 @@ "Int = Array\n", "FaceIndex = int\n", "FaceIndices = Array\n", - "ArrayN = Array\n", - "Array3 = Array\n", - "Array2 = Array\n", - "ArrayNx2 = Array\n", - "ArrayNx3 = Array\n", - "Matrix = jaxlib.xla_extension.ArrayImpl\n", - "PrecisionMatrix = Matrix\n", + "ArrayN = Array\n", + "Array3 = Array\n", + "Array2 = Array\n", + "ArrayNx2 = Array\n", + "ArrayNx3 = Array\n", + "Matrix = jaxlib.xla_extension.ArrayImpl\n", + "PrecisionMatrix = Matrix\n", "CovarianceMatrix = Matrix\n", - "CholeskyMatrix = Matrix\n", - "SquareMatrix = Matrix\n", - "Vector = Array\n", - "Direction = Vector\n", + "CholeskyMatrix = Matrix\n", + "SquareMatrix = Matrix\n", + "Vector = Array\n", + "Direction = Vector\n", "BaseVector = Vector" ] }, diff --git a/scripts/_mkl/notebooks/00b - Utils.ipynb b/scripts/_mkl/notebooks/00b - Utils.ipynb index d0a714cf..1255b59b 100644 --- a/scripts/_mkl/notebooks/00b - Utils.ipynb +++ b/scripts/_mkl/notebooks/00b - Utils.ipynb @@ -13,7 +13,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp utils" + "# |default_exp utils" ] }, { @@ -22,9 +22,9 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import matplotlib.pyplot as plt\n", - "from matplotlib.collections import LineCollection\n", + "from matplotlib.collections import LineCollection\n", "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", @@ -44,8 +44,8 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "key = jax.random.PRNGKey(0)\n", + "# |export\n", + "key = jax.random.PRNGKey(0)\n", "logsumexp = jax.scipy.special.logsumexp" ] }, @@ -55,18 +55,21 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def keysplit(key, *ns):\n", - " if len(ns) == 0: \n", + " if len(ns) == 0:\n", " return jax.random.split(key, 1)[0]\n", " elif len(ns) == 1:\n", - " n, = ns\n", - " if n == 1: return keysplit(key)\n", - " else: return jax.random.split(key, ns[0])\n", + " (n,) = ns\n", + " if n == 1:\n", + " return keysplit(key)\n", + " else:\n", + " return jax.random.split(key, ns[0])\n", " else:\n", " keys = []\n", - " for n in ns: keys.append(keysplit(key, n))\n", - " return keys\n" + " for n in ns:\n", + " keys.append(keysplit(key, n))\n", + " return keys" ] }, { @@ -122,13 +125,15 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def bounding_box(arr, pad=0):\n", " \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n", - " return jnp.array([\n", - " [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n", - " [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n", - " ])" + " return jnp.array(\n", + " [\n", + " [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n", + " [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n", + " ]\n", + " )" ] }, { @@ -137,24 +142,27 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def argmax_axes(a, axes=None):\n", " \"\"\"Argmax along specified axes\"\"\"\n", - " if axes is None: 
return jnp.argmax(a)\n", - " \n", - " n = len(axes) \n", - " axes_ = set(range(a.ndim))\n", + " if axes is None:\n", + " return jnp.argmax(a)\n", + "\n", + " n = len(axes)\n", + " axes_ = set(range(a.ndim))\n", " axes_0 = axes\n", - " axes_1 = sorted(axes_ - set(axes_0)) \n", - " axes_ = axes_0 + axes_1\n", + " axes_1 = sorted(axes_ - set(axes_0))\n", + " axes_ = axes_0 + axes_1\n", "\n", " b = jnp.transpose(a, axes=axes_)\n", " c = b.reshape(np.prod(b.shape[:n]), -1)\n", "\n", " I = jnp.argmax(c, axis=0)\n", - " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n", + " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n", + " b.shape[n:] + (n,)\n", + " )\n", "\n", - " return I" + " return I" ] }, { @@ -177,7 +185,7 @@ "test_shape = (3, 99, 5, 9)\n", "a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n", "\n", - "I = argmax_axes(a, axes=[0,1])\n", + "I = argmax_axes(a, axes=[0, 1])\n", "I.shape" ] }, @@ -194,9 +202,13 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n", - "def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])" + "# |export\n", + "def cam_to_screen(x):\n", + " return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n", + "\n", + "\n", + "def screen_to_cam(y):\n", + " return y[2] * jnp.array([y[0], y[1], 1.0])" ] }, { @@ -205,24 +217,26 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def rot2d(hd): return jnp.array([\n", - " [jnp.cos(hd), -jnp.sin(hd)], \n", - " [jnp.sin(hd), jnp.cos(hd)]\n", - " ]);\n", + "# |export\n", + "def rot2d(hd):\n", + " return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n", + "\n", "\n", - "def pack_2dpose(x,hd): \n", - " return jnp.concatenate([x,jnp.array([hd])])\n", + "def pack_2dpose(x, hd):\n", + " return jnp.concatenate([x, jnp.array([hd])])\n", "\n", - "def apply_2dpose(p, ys): \n", - " return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n", "\n", - "def unit_vec(hd): \n", + "def apply_2dpose(p, ys):\n", + " return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n", + "\n", + "\n", + "def unit_vec(hd):\n", " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n", "\n", + "\n", "def adjust_angle(hd):\n", " \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n", - " return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi" + " return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi" ] }, { @@ -238,12 +252,12 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from genjax.incremental import UnknownChange, NoChange, Diff\n", "\n", "\n", "def argdiffs(args, other=None):\n", - " return tuple(map(lambda v: Diff(v, UnknownChange), args))\n" + " return tuple(map(lambda v: Diff(v, UnknownChange), args))" ] }, { @@ -252,7 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from builtins import property as _property, tuple as _tuple\n", "from typing import Any\n", "\n", @@ -260,10 +274,10 @@ "class Args(tuple):\n", " def __new__(cls, *args, **kwargs):\n", " return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n", - " \n", + "\n", " def __init__(self, *args, **kwargs):\n", " self._d = dict()\n", - " for k,v in kwargs.items():\n", + " for k, v in kwargs.items():\n", " self._d[k] = v\n", " setattr(self, k, v)\n", "\n", @@ -297,30 +311,35 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "# \n", + "# |export\n", + "#\n", "# Monkey patching `sample` for `BuiltinGenerativeFunction`\n", - "# \n", 
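The variadic behavior of `keysplit` above is easy to misread, so here is a minimal usage sketch (an illustration only, not part of the notebook: it assumes the `keysplit` cell has run and that legacy `uint32` PRNG keys are in use):

```python
import jax

key = jax.random.PRNGKey(0)

k = keysplit(key)             # a single fresh key (jax.random.split(key, 1)[0])
ks = keysplit(key, 5)         # a (5, 2) stack of keys from jax.random.split
ka, kb = keysplit(key, 3, 2)  # a list of two stacks, shapes (3, 2) and (2, 2);
                              # note both stacks are derived from the same `key`
```

Note that `keysplit(key, 1)` intentionally degenerates to the single-key case rather than returning a length-1 stack.
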
+ "#\n", "cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n", "\n", + "\n", "def genjax_sample(self, key, *args, **kwargs):\n", " tr = self.simulate(key, args)\n", " return tr.get_retval()\n", "\n", + "\n", "setattr(cls, \"sample\", genjax_sample)\n", "\n", "\n", - "# \n", + "#\n", "# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n", - "# \n", + "#\n", "cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n", "\n", + "\n", "def deff_gen_func_call(self, key, **kwargs):\n", " return self.gen_fn.sample(key, *self.args, **kwargs)\n", "\n", + "\n", "def deff_gen_func_logpdf(self, x, **kwargs):\n", " return self.gen_fn.logpdf(x, *self.args, **kwargs)\n", "\n", + "\n", "setattr(cls, \"__call__\", deff_gen_func_call)\n", "setattr(cls, \"sample\", deff_gen_func_call)\n", "setattr(cls, \"logpdf\", deff_gen_func_logpdf)" diff --git a/scripts/_mkl/notebooks/01 - Plotting.ipynb b/scripts/_mkl/notebooks/01 - Plotting.ipynb index fd79619e..9956af71 100644 --- a/scripts/_mkl/notebooks/01 - Plotting.ipynb +++ b/scripts/_mkl/notebooks/01 - Plotting.ipynb @@ -13,7 +13,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp plotting" + "# |default_exp plotting" ] }, { @@ -22,7 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import jax.numpy as jnp\n", @@ -35,21 +35,23 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def rgba_from_vals(vs, q=0.0, cmap=\"viridis\", vmin=None, vmax=None):\n", - " if isinstance(q,list):\n", + " if isinstance(q, list):\n", " v_min = np.quantile(vs, q[0])\n", " v_max = np.quantile(vs, q[1])\n", " else:\n", " v_min = np.quantile(vs, q)\n", " v_max = np.max(vs)\n", "\n", - " if vmax is not None: v_max = vmax\n", - " if vmin is not None: v_min = vmin\n", + " if vmax is not None:\n", + " v_max = vmax\n", + " if vmin is not None:\n", + " v_min = vmin\n", "\n", - " cm = getattr(plt.cm, cmap)\n", + " cm = getattr(plt.cm, cmap)\n", " vs_ = np.clip(vs, v_min, v_max)\n", - " cs = cm(plt.Normalize()(vs_))\n", + " cs = cm(plt.Normalize()(vs_))\n", " return cs" ] }, @@ -83,7 +85,7 @@ "vs = np.random.randn(10000)\n", "cs = rgba_from_vals(vs, q=0.0, cmap=\"binary\")\n", "\n", - "plt.figure(figsize=(2,1))\n", + "plt.figure(figsize=(2, 1))\n", "plt.scatter(vs, np.zeros(len(vs)), c=cs)\n", "np.quantile(vs, 0.0) == np.min(vs)" ] @@ -94,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from matplotlib.collections import LineCollection\n", "\n", "\n", @@ -132,8 +134,8 @@ ], "source": [ "n = 10\n", - "a = np.random.randn(n,2)\n", - "b = np.random.randn(n,2)\n", + "a = np.random.randn(n, 2)\n", + "b = np.random.randn(n, 2)\n", "v = np.random.randn(n)\n", "c = rgba_from_vals(v, q=0.0, cmap=\"viridis\")\n", "lc = line_collection(a, b, linewidth=2)\n", @@ -141,7 +143,7 @@ "\n", "\n", "# -------------------\n", - "plt.figure(figsize=(2,2))\n", + "plt.figure(figsize=(2, 2))\n", "plt.gca().set_aspect(1)\n", "plt.gca().add_collection(lc)\n", "plt.scatter(*a.T, c=c)\n", @@ -154,13 +156,14 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def plot_segs(segs, c=\"k\", linewidth=1, ax=None, **kwargs):\n", - " if ax is None: ax = plt.gca()\n", + "# |export\n", + "def plot_segs(segs, c=\"k\", linewidth=1, ax=None, **kwargs):\n", + " if ax is None:\n", + " ax = plt.gca()\n", " n = 10\n", - " segs = segs.reshape(-1,2,2)\n", - " a = segs[:,0]\n", 
- " b = segs[:,1]\n", + " segs = segs.reshape(-1, 2, 2)\n", + " a = segs[:, 0]\n", + " b = segs[:, 1]\n", " lc = line_collection(a, b, linewidth=linewidth, **kwargs)\n", " lc.set_colors(c)\n", " ax.add_collection(lc)" @@ -193,13 +196,13 @@ } ], "source": [ - "segs = np.stack([a,b],axis=1) \n", + "segs = np.stack([a, b], axis=1)\n", "# -------------------\n", - "plt.figure(figsize=(2,2))\n", + "plt.figure(figsize=(2, 2))\n", "plt.gca().set_aspect(1)\n", "plot_segs(segs, c=\"b\", linewidth=1, zorder=-1)\n", "plt.scatter(*a.T, c=c)\n", - "plt.scatter(*b.T, c=c)\n" + "plt.scatter(*b.T, c=c)" ] }, { @@ -208,11 +211,12 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def zoom_in(x, pad, ax=None):\n", - " if ax is None: ax = plt.gca()\n", - " ax.set_xlim(np.min(x[...,0])-pad, np.max(x[...,0])+pad)\n", - " ax.set_ylim(np.min(x[...,1])-pad, np.max(x[...,1])+pad)" + " if ax is None:\n", + " ax = plt.gca()\n", + " ax.set_xlim(np.min(x[..., 0]) - pad, np.max(x[..., 0]) + pad)\n", + " ax.set_ylim(np.min(x[..., 1]) - pad, np.max(x[..., 1]) + pad)" ] }, { @@ -221,25 +225,39 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def unit_vec(hd): \n", + "# |export\n", + "def unit_vec(hd):\n", " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n", "\n", - "def plot_poses(ps, sc=None, r=0.5, clip=-1e12, cs=None, c=\"lightgray\", cmap=\"viridis\", ax=None, q=0.0, zorder=None, linewidth=2):\n", - " if ax is None: ax = plt.gca()\n", + "\n", + "def plot_poses(\n", + " ps,\n", + " sc=None,\n", + " r=0.5,\n", + " clip=-1e12,\n", + " cs=None,\n", + " c=\"lightgray\",\n", + " cmap=\"viridis\",\n", + " ax=None,\n", + " q=0.0,\n", + " zorder=None,\n", + " linewidth=2,\n", + "):\n", + " if ax is None:\n", + " ax = plt.gca()\n", " ax.set_aspect(1)\n", - " ps = ps.reshape(-1,3)\n", + " ps = ps.reshape(-1, 3)\n", "\n", - " a = ps[:,:2]\n", - " b = a + r * jax.vmap(unit_vec)(ps[:,2])\n", + " a = ps[:, :2]\n", + " b = a + r * jax.vmap(unit_vec)(ps[:, 2])\n", "\n", " if cs is None:\n", " if sc is None:\n", " cs = c\n", " else:\n", " sc = sc.reshape(-1)\n", - " sc = jnp.where(jnp==-jnp.inf, clip, sc)\n", - " sc = jnp.clip(sc, clip, jnp.max(sc))\n", + " sc = jnp.where(jnp == -jnp.inf, clip, sc)\n", + " sc = jnp.clip(sc, clip, jnp.max(sc))\n", " sc = jnp.clip(sc, jnp.quantile(sc, q), jnp.max(sc))\n", " cs = getattr(plt.cm, cmap)(plt.Normalize()(sc))\n", "\n", @@ -248,9 +266,7 @@ " b = b[order]\n", " cs = cs[order]\n", "\n", - "\n", - "\n", - " ax.add_collection(line_collection(a,b, c=cs, zorder=zorder, linewidth=linewidth));" + " ax.add_collection(line_collection(a, b, c=cs, zorder=zorder, linewidth=linewidth))" ] }, { @@ -259,13 +275,14 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def plot_pose(p, r=0.5, c=\"red\", ax=None,zorder=None, linewidth=2):\n", - " if ax is None: ax = plt.gca()\n", + "# |export\n", + "def plot_pose(p, r=0.5, c=\"red\", ax=None, zorder=None, linewidth=2):\n", + " if ax is None:\n", + " ax = plt.gca()\n", " ax.set_aspect(1)\n", " a = p[:2]\n", - " b = a + r*unit_vec(p[2])\n", - " ax.plot([a[0],b[0]],[a[1],b[1]], c=c, zorder=zorder, linewidth=linewidth)\n" + " b = a + r * unit_vec(p[2])\n", + " ax.plot([a[0], b[0]], [a[1], b[1]], c=c, zorder=zorder, linewidth=linewidth)" ] }, { @@ -285,15 +302,15 @@ } ], "source": [ - "ps = jnp.linspace(0.0, 1., 10)[:,None]*jnp.array([1.,-.25,jnp.pi/2])\n", - "sc = ps[:,2]\n", + "ps = jnp.linspace(0.0, 1.0, 10)[:, None] * jnp.array([1.0, -0.25, jnp.pi / 2])\n", + "sc = ps[:, 2]\n", "# -------------------\n", - 
"plt.figure(figsize=(2,2))\n", - "plt.xlim(-2,2)\n", - "plt.ylim(-2,2)\n", + "plt.figure(figsize=(2, 2))\n", + "plt.xlim(-2, 2)\n", + "plt.ylim(-2, 2)\n", "plt.gca().set_aspect(1)\n", "plot_poses(ps, sc=sc, cmap=\"viridis\", ax=None, q=0.0, linewidth=2)\n", - "plot_pose(jnp.array([0,1,0]), r=0.5, c=\"magenta\", ax=None,zorder=None, linewidth=1)" + "plot_pose(jnp.array([0, 1, 0]), r=0.5, c=\"magenta\", ax=None, zorder=None, linewidth=1)" ] }, { diff --git a/scripts/_mkl/notebooks/02 - Pose.ipynb b/scripts/_mkl/notebooks/02 - Pose.ipynb index 21b14e60..8cc0d6f2 100644 --- a/scripts/_mkl/notebooks/02 - Pose.ipynb +++ b/scripts/_mkl/notebooks/02 - Pose.ipynb @@ -13,7 +13,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp pose" + "# |default_exp pose" ] }, { @@ -22,17 +22,17 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import jax\n", "import jax.numpy as jnp\n", "import genjax\n", - "from genjax.generative_functions.distributions import ExactDensity\n", + "from genjax.generative_functions.distributions import ExactDensity\n", "from dataclasses import dataclass\n", "from collections import namedtuple\n", "from plum import dispatch\n", "\n", - "PI = jnp.pi\n", - "TWOPI = 2*PI" + "PI = jnp.pi\n", + "TWOPI = 2 * PI" ] }, { @@ -80,23 +80,25 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def rot2d(hd): return jnp.array([\n", - " [jnp.cos(hd), -jnp.sin(hd)], \n", - " [jnp.sin(hd), jnp.cos(hd)]\n", - " ]);\n", + "# |export\n", + "def rot2d(hd):\n", + " return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n", "\n", - "def pack_2dpose(x,hd): \n", - " return jnp.concatenate([x,jnp.array([hd])])\n", "\n", - "def apply_2dpose(p, ys): \n", - " return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n", + "def pack_2dpose(x, hd):\n", + " return jnp.concatenate([x, jnp.array([hd])])\n", "\n", - "def unit_vec(hd): \n", + "\n", + "def apply_2dpose(p, ys):\n", + " return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n", + "\n", + "\n", + "def unit_vec(hd):\n", " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n", "\n", - "def adjust_angle(hd): \n", - " return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi" + "\n", + "def adjust_angle(hd):\n", + " return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi" ] }, { @@ -162,43 +164,45 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "CAM_ALONG_X = jnp.array([\n", - " [0, 0, 1],\n", - " [-1, 0, 0],\n", - " [0, -1, 0]\n", - "])\n", + "# |export\n", + "CAM_ALONG_X = jnp.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])\n", "\n", "\n", "def rot_x(theta):\n", - " return jnp.array([\n", - " [1, 0, 0],\n", - " [0, jnp.cos(theta), -jnp.sin(theta)],\n", - " [0, jnp.sin(theta), jnp.cos(theta)]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [1, 0, 0],\n", + " [0, jnp.cos(theta), -jnp.sin(theta)],\n", + " [0, jnp.sin(theta), jnp.cos(theta)],\n", + " ]\n", + " )\n", "\n", "\n", "def rot_y(theta):\n", - " return jnp.array([\n", - " [jnp.cos(theta), 0, -jnp.sin(theta)],\n", - " [0, 1, 0],\n", - " [jnp.sin(theta), 0, jnp.cos(theta)]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [jnp.cos(theta), 0, -jnp.sin(theta)],\n", + " [0, 1, 0],\n", + " [jnp.sin(theta), 0, jnp.cos(theta)],\n", + " ]\n", + " )\n", "\n", "\n", "def rot_z(theta):\n", - " return jnp.array([\n", - " [jnp.cos(theta), -jnp.sin(theta), 0],\n", - " [jnp.sin(theta), jnp.cos(theta), 0],\n", - " [0, 0, 1]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [jnp.cos(theta), -jnp.sin(theta), 0],\n", + " [jnp.sin(theta), jnp.cos(theta), 0],\n", + " [0, 0, 
1],\n", + " ]\n", + " )\n", "\n", "\n", "def from_euler(rot, pitch=0.0, roll=0.0):\n", " \"\"\"\n", " Imagine you stand on xy-plane and rotate (z-axis), pitch (y'-axis), and roll (x''-axis).\n", " \"\"\"\n", - " return rot_z(rot)@rot_y(pitch)@rot_x(roll)\n", + " return rot_z(rot) @ rot_y(pitch) @ rot_x(roll)\n", "\n", "\n", "def look_at(v, roll=0.0, cam=True):\n", @@ -206,9 +210,9 @@ " R = CAM_ALONG_X if cam else jnp.eye(3)\n", "\n", " n = jnp.linalg.norm(v)\n", - " rot = jnp.arctan2(v[1],v[0])\n", - " pitch = jnp.arctan2(v[2],n)\n", - " return from_euler(rot, pitch, roll)@R" + " rot = jnp.arctan2(v[1], v[0])\n", + " pitch = jnp.arctan2(v[2], n)\n", + " return from_euler(rot, pitch, roll) @ R" ] }, { @@ -217,43 +221,52 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def ax_to_ind(c):\n", - " lookup = {\"x\":0, \"y\":1, \"z\":2}\n", + " lookup = {\"x\": 0, \"y\": 1, \"z\": 2}\n", " return lookup[c]\n", "\n", "\n", "class Rotation(object):\n", " @staticmethod\n", " def _x(theta):\n", - " return jnp.array([\n", - " [1, 0, 0],\n", - " [0, jnp.cos(theta), -jnp.sin(theta)],\n", - " [0, jnp.sin(theta), jnp.cos(theta)]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [1, 0, 0],\n", + " [0, jnp.cos(theta), -jnp.sin(theta)],\n", + " [0, jnp.sin(theta), jnp.cos(theta)],\n", + " ]\n", + " )\n", "\n", " @staticmethod\n", " def _y(theta):\n", - " return jnp.array([\n", - " [jnp.cos(theta), 0, -jnp.sin(theta)],\n", - " [0, 1, 0],\n", - " [jnp.sin(theta), 0, jnp.cos(theta)]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [jnp.cos(theta), 0, -jnp.sin(theta)],\n", + " [0, 1, 0],\n", + " [jnp.sin(theta), 0, jnp.cos(theta)],\n", + " ]\n", + " )\n", "\n", " @staticmethod\n", " def _z(theta):\n", - " return jnp.array([\n", - " [jnp.cos(theta), -jnp.sin(theta), 0],\n", - " [jnp.sin(theta), jnp.cos(theta), 0],\n", - " [0, 0, 1]\n", - " ])\n", + " return jnp.array(\n", + " [\n", + " [jnp.cos(theta), -jnp.sin(theta), 0],\n", + " [jnp.sin(theta), jnp.cos(theta), 0],\n", + " [0, 0, 1],\n", + " ]\n", + " )\n", "\n", " @staticmethod\n", - " def _ax(ax:str, theta):\n", - " if ax == \"x\": return Rotation._x(theta)\n", - " elif ax == \"y\": return Rotation._y(theta)\n", - " elif ax == \"z\": return Rotation._z(theta)\n", - " \n", + " def _ax(ax: str, theta):\n", + " if ax == \"x\":\n", + " return Rotation._x(theta)\n", + " elif ax == \"y\":\n", + " return Rotation._y(theta)\n", + " elif ax == \"z\":\n", + " return Rotation._z(theta)\n", + "\n", " @staticmethod\n", " def from_euler(order, angles):\n", " \"\"\"\n", @@ -262,18 +275,22 @@ " angles : Array of length 3, e.g. 
[0, 0, 0]\n", " \"\"\"\n", " rot_ax = Rotation._ax\n", - " return rot_ax(order[0], angles[0])@rot_ax(order[1],angles[1])@rot_ax(order[2],angles[2])\n", - " \n", + " return (\n", + " rot_ax(order[0], angles[0])\n", + " @ rot_ax(order[1], angles[1])\n", + " @ rot_ax(order[2], angles[2])\n", + " )\n", + "\n", " @staticmethod\n", " def look_at(v, roll=0.0, order=\"zyx\"):\n", - " \"\"\"\n", - " \"\"\"\n", + " \"\"\" \"\"\"\n", " n = jnp.linalg.norm(v)\n", - " rot = jnp.arctan2(v[1],v[0])\n", - " pitch = jnp.arctan2(v[2],n)\n", - " return Rotation.from_euler(order, [rot, pitch, roll])@yzX[:3,:3]\n", - " \n", - "Rot = Rotation\n" + " rot = jnp.arctan2(v[1], v[0])\n", + " pitch = jnp.arctan2(v[2], n)\n", + " return Rotation.from_euler(order, [rot, pitch, roll]) @ yzX[:3, :3]\n", + "\n", + "\n", + "Rot = Rotation" ] }, { @@ -293,11 +310,7 @@ } ], "source": [ - "(\n", - " jnp.linalg.det(rot_x(0.0)), \n", - " jnp.linalg.det(rot_y(0.0)), \n", - " jnp.linalg.det(rot_z(0.0)) \n", - ")" + "(jnp.linalg.det(rot_x(0.0)), jnp.linalg.det(rot_y(0.0)), jnp.linalg.det(rot_z(0.0)))" ] }, { @@ -313,25 +326,29 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "Pose = namedtuple(\"Pose\", [\"x\", \"r\"])\n", + "# |export\n", + "Pose = namedtuple(\"Pose\", [\"x\", \"r\"])\n", + "\n", "\n", "@dispatch\n", - "def unpack_pose(p:Pose): \n", + "def unpack_pose(p: Pose):\n", " return p.x, p.r\n", "\n", + "\n", "@dispatch\n", - "def unpack_pose(R:jnp.ndarray): \n", - " return R[:3,3], R[:3,:3]\n", + "def unpack_pose(R: jnp.ndarray):\n", + " return R[:3, 3], R[:3, :3]\n", + "\n", + "\n", + "def pack_pose(x, r):\n", + " return jnp.concatenate(\n", + " [jnp.concatenate([r, x[:, None]], axis=1), jnp.array([[0, 0, 0, 1]])], axis=0\n", + " )\n", "\n", - "def pack_pose(x, r): \n", - " return jnp.concatenate([\n", - " jnp.concatenate([r, x[:,None]], axis=1), \n", - " jnp.array([[0,0,0,1]])], axis=0)\n", "\n", "def apply_pose(p, x):\n", " t, r = unpack_pose(p)\n", - " return x@r.T + t" + " return x @ r.T + t" ] }, { @@ -340,23 +357,47 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "\n", - "def mpl_plot_pose(ax, p, s=0.1, length=0.1, normalize=True, **kwargs):\n", "\n", - " t,r = unpack_pose(p)\n", + "def mpl_plot_pose(ax, p, s=0.1, length=0.1, normalize=True, **kwargs):\n", + " t, r = unpack_pose(p)\n", " # Coordinate frame data\n", " origin = t\n", - " x_axis = s*r[:3,0]\n", - " y_axis = s*r[:3,1]\n", - " z_axis = s*r[:3,2]\n", + " x_axis = s * r[:3, 0]\n", + " y_axis = s * r[:3, 1]\n", + " z_axis = s * r[:3, 2]\n", "\n", " # Plotting the coordinate frame\n", - " ax.quiver(*origin, *x_axis, color='r', label='X-axis', length=length, normalize=normalize, **kwargs)\n", - " ax.quiver(*origin, *y_axis, color='g', label='Y-axis', length=length, normalize=normalize, **kwargs)\n", - " ax.quiver(*origin, *z_axis, color='b', label='Z-axis', length=length, normalize=normalize, **kwargs)" + " ax.quiver(\n", + " *origin,\n", + " *x_axis,\n", + " color=\"r\",\n", + " label=\"X-axis\",\n", + " length=length,\n", + " normalize=normalize,\n", + " **kwargs,\n", + " )\n", + " ax.quiver(\n", + " *origin,\n", + " *y_axis,\n", + " color=\"g\",\n", + " label=\"Y-axis\",\n", + " length=length,\n", + " normalize=normalize,\n", + " **kwargs,\n", + " )\n", + " ax.quiver(\n", + " *origin,\n", + " *z_axis,\n", + " color=\"b\",\n", + " label=\"Z-axis\",\n", + " length=length,\n", + " normalize=normalize,\n", + " **kwargs,\n", + " )" ] 
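Since the `Pose` helpers fix a row-vector convention, a short sketch may help; this assumes the cells above (`rot_z`, `pack_pose`, `apply_pose`) have run and is only an illustration:

```python
import jax.numpy as jnp

t = jnp.array([1.0, 2.0, 3.0])
R = rot_z(jnp.pi / 2)

P = pack_pose(t, R)  # 4x4 homogeneous matrix [[R, t], [0, 0, 0, 1]]

xs = jnp.array([[1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0]])
ys = apply_pose(P, xs)  # row vectors: xs @ R.T + t

# the same map in homogeneous coordinates
xs_h = jnp.concatenate([xs, jnp.ones((len(xs), 1))], axis=1)
assert jnp.allclose((P @ xs_h.T).T[:, :3], ys, atol=1e-6)
```

Because `unpack_pose` dispatches on the argument type, `apply_pose` accepts either a `Pose` namedtuple or a packed 4x4 matrix.
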
}, { @@ -374,16 +415,16 @@ } ], "source": [ - "#|hide\n", - "ps = jax.random.uniform(key, (100,4,4))\n", + "# |hide\n", + "ps = jax.random.uniform(key, (100, 4, 4))\n", "ps = ps.at[0].set(jnp.eye(4))\n", "\n", - "xs = jax.random.uniform(key, (100,3))\n", + "xs = jax.random.uniform(key, (100, 3))\n", "\n", "ys = jax.vmap(apply_pose)(ps, xs)\n", "print(ys.shape, jnp.all(ys[0] == xs[0]))\n", "\n", - "ys = jax.vmap(apply_pose, (0,None))(ps, xs)\n", + "ys = jax.vmap(apply_pose, (0, None))(ps, xs)\n", "print(ys.shape, jnp.all(ys[0] == xs))" ] }, @@ -393,13 +434,10 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def lift_pose(x, hd, z=0.0, pitch=0.0, roll=0.0):\n", " \"\"\"Lifts a 2d pose (x,hd) to 3d\"\"\"\n", - " return pack_pose(\n", - " jnp.concatenate([x, jnp.array([z])]),\n", - " from_euler(hd) @ CAM_ALONG_X\n", - " ) " + " return pack_pose(jnp.concatenate([x, jnp.array([z])]), from_euler(hd) @ CAM_ALONG_X)" ] }, { @@ -415,7 +453,7 @@ "# return pack_pose(\n", "# jnp.concatenate([x, jnp.array([z])]),\n", "# from_euler(hd) @ CAM_ALONG_X\n", - "# ) " + "# )" ] }, { @@ -436,22 +474,22 @@ ], "source": [ "ps = [\n", - " lift_pose(jnp.array([0.0,0.5]), 0.0, z=0.0),\n", - " lift_pose(jnp.array([0.0,0.5]), jnp.pi/4, z=.1),\n", - " lift_pose(jnp.array([0.0,0.5]), jnp.pi/2, z=.2),\n", - " lift_pose(jnp.array([0.0,0.5]), 3*jnp.pi/4, z=.3),\n", - " lift_pose(jnp.array([0.0,0.5]), jnp.pi, z=0.4)\n", + " lift_pose(jnp.array([0.0, 0.5]), 0.0, z=0.0),\n", + " lift_pose(jnp.array([0.0, 0.5]), jnp.pi / 4, z=0.1),\n", + " lift_pose(jnp.array([0.0, 0.5]), jnp.pi / 2, z=0.2),\n", + " lift_pose(jnp.array([0.0, 0.5]), 3 * jnp.pi / 4, z=0.3),\n", + " lift_pose(jnp.array([0.0, 0.5]), jnp.pi, z=0.4),\n", "]\n", "\n", "# -----------------\n", "fig = plt.figure()\n", - "ax = fig.add_subplot(111, projection='3d')\n", - "ax.set_box_aspect((1,1,1))\n", - "ax.set_xlim(-1,1)\n", - "ax.set_ylim(-1,1)\n", - "ax.set_zlim(-1,1)\n", + "ax = fig.add_subplot(111, projection=\"3d\")\n", + "ax.set_box_aspect((1, 1, 1))\n", + "ax.set_xlim(-1, 1)\n", + "ax.set_ylim(-1, 1)\n", + "ax.set_zlim(-1, 1)\n", "for p in ps:\n", - " mpl_plot_pose(ax, p, s=.1, length=.1)\n", + " mpl_plot_pose(ax, p, s=0.1, length=0.1)\n", "fig.show();" ] }, @@ -479,12 +517,12 @@ } ], "source": [ - "#|hide\n", - "xs = jax.random.normal(key, (100,2))\n", - "hds = jnp.pi*2*jax.random.normal(key, (100,))\n", + "# |hide\n", + "xs = jax.random.normal(key, (100, 2))\n", + "hds = jnp.pi * 2 * jax.random.normal(key, (100,))\n", "\n", - "ps = jax.vmap(lift_pose, (0,0,None))(xs, hds, 1.0)\n", - "ps.shape\n" + "ps = jax.vmap(lift_pose, (0, 0, None))(xs, hds, 1.0)\n", + "ps.shape" ] }, { @@ -505,15 +543,24 @@ ], "source": [ "n = 10\n", - "xs = jnp.stack([jnp.linspace(-1,1,n),jnp.zeros(n)], axis=1)\n", - "y = jnp.array([0.5,1.0])\n", - "hds = jnp.arctan2(y[1] - xs[:,1], y[0] - xs[:,0])\n", + "xs = jnp.stack([jnp.linspace(-1, 1, n), jnp.zeros(n)], axis=1)\n", + "y = jnp.array([0.5, 1.0])\n", + "hds = jnp.arctan2(y[1] - xs[:, 1], y[0] - xs[:, 0])\n", "\n", "# -----------------\n", - "plt.figure(figsize=(2,2))\n", + "plt.figure(figsize=(2, 2))\n", "plt.gca().set_aspect(\"equal\")\n", "plt.scatter(*xs.T, marker=\"o\", c=\"b\")\n", - "plt.quiver(xs[:,0],xs[:,1], jnp.cos(hds), jnp.sin(hds), color='b', label='X-axis', scale=2., units=\"xy\")\n", + "plt.quiver(\n", + " xs[:, 0],\n", + " xs[:, 1],\n", + " jnp.cos(hds),\n", + " jnp.sin(hds),\n", + " color=\"b\",\n", + " label=\"X-axis\",\n", + " scale=2.0,\n", + " units=\"xy\",\n", + ")\n", 
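As a quick numerical sanity check on `look_at` and `lift_pose` (again a sketch, assuming the cells above have run):

```python
import jax.numpy as jnp

v = jnp.array([1.0, 0.5, 0.25])  # desired viewing direction
R = look_at(v)

# look_at composes Euler rotations with CAM_ALONG_X, so the result should
# be a proper rotation: orthonormal with determinant +1.
print(jnp.allclose(R.T @ R, jnp.eye(3), atol=1e-5))  # True
print(jnp.linalg.det(R))                             # ~1.0

# lift_pose embeds a 2d pose (x, hd) as a 4x4 camera pose at height z.
P = lift_pose(jnp.array([0.0, 0.5]), jnp.pi / 4, z=0.1)
print(P.shape)  # (4, 4)
```
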
"plt.scatter(*y, c=\"r\");" ] }, @@ -523,9 +570,8 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "y_ = jnp.concatenate([y, jnp.array([0.5])])\n", - "xs_ = jnp.concatenate([xs, jnp.zeros((n,1))], axis=1)" + "y_ = jnp.concatenate([y, jnp.array([0.5])])\n", + "xs_ = jnp.concatenate([xs, jnp.zeros((n, 1))], axis=1)" ] }, { @@ -557,14 +603,14 @@ "print(y_)\n", "# -----------------\n", "fig = plt.figure()\n", - "ax = fig.add_subplot(111, projection='3d')\n", - "ax.set_box_aspect((1,1,0.5))\n", - "ax.set_xlim(-1,1)\n", - "ax.set_ylim(-1,1)\n", - "ax.set_zlim(-0,1)\n", + "ax = fig.add_subplot(111, projection=\"3d\")\n", + "ax.set_box_aspect((1, 1, 0.5))\n", + "ax.set_xlim(-1, 1)\n", + "ax.set_ylim(-1, 1)\n", + "ax.set_zlim(-0, 1)\n", "ax.scatter(*y_, c=\"red\")\n", "for p in ps:\n", - " mpl_plot_pose(ax, p, s=.1, length=.1)\n", + " mpl_plot_pose(ax, p, s=0.1, length=0.1)\n", "fig.show();" ] }, @@ -592,22 +638,19 @@ } ], "source": [ - "ps = jax.vmap(\n", - " lambda x,y: pack_pose(x, look_at(y-x)), \n", - " (0,None)\n", - " ) (xs_, y_)\n", + "ps = jax.vmap(lambda x, y: pack_pose(x, look_at(y - x)), (0, None))(xs_, y_)\n", "\n", "print(ps.shape)\n", "# -----------------\n", "fig = plt.figure()\n", - "ax = fig.add_subplot(111, projection='3d')\n", - "ax.set_box_aspect((1,1,0.5))\n", - "ax.set_xlim(-1,1)\n", - "ax.set_ylim(-1,1)\n", - "ax.set_zlim(-0,1)\n", + "ax = fig.add_subplot(111, projection=\"3d\")\n", + "ax.set_box_aspect((1, 1, 0.5))\n", + "ax.set_xlim(-1, 1)\n", + "ax.set_ylim(-1, 1)\n", + "ax.set_zlim(-0, 1)\n", "ax.scatter(*y_, c=\"red\")\n", "for p in ps:\n", - " mpl_plot_pose(ax, p, s=.1, length=.1)\n", + " mpl_plot_pose(ax, p, s=0.1, length=0.1)\n", "fig.show();" ] }, diff --git a/scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb b/scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb index 5bcd1d0d..217692da 100644 --- a/scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb +++ b/scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb @@ -16,7 +16,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp trimesh_to_gaussians" + "# |default_exp trimesh_to_gaussians" ] }, { @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "_doc_ = \"\"\"\n", "# Trimesh to Gaussians\n", "> Pretty much self-explanatory\n", @@ -101,7 +101,7 @@ } ], "source": [ - "#|export\n", + "# |export\n", "import bayes3d as b3d\n", "import trimesh\n", "from bayes3d._mkl.utils import *\n", @@ -118,14 +118,14 @@ "Shape = int | tuple[int, ...]\n", "FaceIndex = int\n", "FaceIndices = Array\n", - "Array3 = Array\n", - "Array2 = Array\n", - "ArrayNx2 = Array\n", - "ArrayNx3 = Array\n", - "Matrix = jaxlib.xla_extension.ArrayImpl\n", - "PrecisionMatrix = Matrix\n", + "Array3 = Array\n", + "Array2 = Array\n", + "ArrayNx2 = Array\n", + "ArrayNx3 = Array\n", + "Matrix = jaxlib.xla_extension.ArrayImpl\n", + "PrecisionMatrix = Matrix\n", "CovarianceMatrix = Matrix\n", - "SquareMatrix = Matrix\n", + "SquareMatrix = Matrix\n", "Vector = Array" ] }, @@ -144,14 +144,14 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def area_of_triangle(a:Array3, b:Array3, c:Array3=jnp.zeros(3)):\n", + "# |export\n", + "def area_of_triangle(a: Array3, b: Array3, c: Array3 = jnp.zeros(3)):\n", " \"\"\"Computes the area of a triangle spanned by a,b[,c].\"\"\"\n", - " x = a-c\n", - " y = b-c\n", + " x = a - c\n", + " y = b - c\n", " w = jnp.linalg.norm(x)\n", - " h = jnp.linalg.norm(y - jnp.dot(x, y)/w**2 * x)\n", - " area = w*h/2\n", + " h = jnp.linalg.norm(y - jnp.dot(x, y) 
/ w**2 * x)\n", + " area = w * h / 2\n", "\n", " return area\n", "\n", @@ -160,11 +160,11 @@ " a = vertices[f[1]] - vertices[f[0]]\n", " b = vertices[f[2]] - vertices[f[0]]\n", " area = area_of_triangle(a, b)\n", - " normal = jnp.cross(a,b)\n", + " normal = jnp.cross(a, b)\n", " return area, normal\n", "\n", "\n", - "compute_area_and_normals = jit(vmap(_compute_area_and_normal, (0,None)))" + "compute_area_and_normals = jit(vmap(_compute_area_and_normal, (0, None)))" ] }, { @@ -173,10 +173,10 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def patch_trimesh(mesh:trimesh.base.Trimesh):\n", + "# |export\n", + "def patch_trimesh(mesh: trimesh.base.Trimesh):\n", " \"\"\"\n", - " Return a patched copy of a trimesh object, and \n", + " Return a patched copy of a trimesh object, and\n", " ensure it to have a texture and the following attributes:\n", " - `mesh.visual.uv`\n", " - `copy.visual.material.to_color`\n", @@ -190,15 +190,15 @@ " return patched_mesh\n", "\n", "\n", - "def texture_uv_basis(face_idx:Array, mesh):\n", + "def texture_uv_basis(face_idx: Array, mesh):\n", " \"\"\"\n", - " Takes a face index and returns the three uv-vectors \n", + " Takes a face index and returns the three uv-vectors\n", " spanning the face in texture space.\n", " \"\"\"\n", " return mesh.visual.uv[mesh.faces[face_idx]]\n", "\n", "\n", - "def uv_to_color(uv:ArrayNx2, mesh):\n", + "def uv_to_color(uv: ArrayNx2, mesh):\n", " \"\"\"Takes texture-uv coordinates and returns the corresponding color.\"\"\"\n", " return mesh.visual.material.to_color(uv)" ] @@ -209,50 +209,50 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def barycentric_to_mesh(p:Array3, i:FaceIndex, mesh):\n", + "# |export\n", + "def barycentric_to_mesh(p: Array3, i: FaceIndex, mesh):\n", " \"\"\"Converts a point in barycentric coordinates `p` on a face `i` to a 3d point on the mesh.\"\"\"\n", - " x = jnp.sum(p[:,None]*mesh.vertices[mesh.faces[i]], axis=0)\n", + " x = jnp.sum(p[:, None] * mesh.vertices[mesh.faces[i]], axis=0)\n", " return x\n", "\n", "\n", "def sample_from_face(key, n, i, mesh):\n", " \"\"\"\n", - " Sample random points `xs`, barycentric coordinates `ps`, and \n", + " Sample random points `xs`, barycentric coordinates `ps`, and\n", " face indices `fs` from a mesh.\n", " \"\"\"\n", - " _, key = keysplit(key,1,1)\n", - " ps = jax.random.dirichlet(key, jnp.ones(3), (n,)).reshape((n,3,1))\n", - " xs = jnp.sum(ps*mesh.vertices[mesh.faces[i]], axis=1)\n", + " _, key = keysplit(key, 1, 1)\n", + " ps = jax.random.dirichlet(key, jnp.ones(3), (n,)).reshape((n, 3, 1))\n", + " xs = jnp.sum(ps * mesh.vertices[mesh.faces[i]], axis=1)\n", " return xs, ps\n", "\n", "\n", "def sample_from_mesh(key, n, mesh):\n", " \"\"\"\n", - " Returns random points `xs`, barycentric coordinates `ps`, and \n", + " Returns random points `xs`, barycentric coordinates `ps`, and\n", " face indices `fs` from a mesh.\n", " \"\"\"\n", - " _, keys = keysplit(key,1,2)\n", + " _, keys = keysplit(key, 1, 2)\n", "\n", - " # Sample `n` faces from the mesh with \n", - " # probability proportional to their area. 
\n", + " # Sample `n` faces from the mesh with\n", + " # probability proportional to their area.\n", " areas, _ = compute_area_and_normals(mesh.faces, mesh.vertices)\n", " fs = jax.random.categorical(keys[0], jnp.log(areas), shape=(n,))\n", "\n", " # Sample barycentric coordinates `bs` for each sampled face\n", " # and compute the corresponding world coordinates `xs`.\n", - " ps = jax.random.dirichlet(keys[1], jnp.ones(3), (n,)).reshape((n,3,1))\n", - " xs = jnp.sum(ps*mesh.vertices[mesh.faces[fs]], axis=1)\n", + " ps = jax.random.dirichlet(keys[1], jnp.ones(3), (n,)).reshape((n, 3, 1))\n", + " xs = jnp.sum(ps * mesh.vertices[mesh.faces[fs]], axis=1)\n", " return xs, ps, fs\n", - " \n", "\n", - "def get_colors_from_mesh(ps:ArrayNx3, fs:FaceIndices, mesh):\n", + "\n", + "def get_colors_from_mesh(ps: ArrayNx3, fs: FaceIndices, mesh):\n", " \"\"\"\n", - " Returns the colors of the points on the mesh given \n", + " Returns the colors of the points on the mesh given\n", " their barycentric coordinates `ps` and face indices `fs`.\n", " \"\"\"\n", " uvs = jnp.sum(ps * texture_uv_basis(fs, mesh), axis=1)\n", - " cs = uv_to_color(uvs, mesh)/255\n", + " cs = uv_to_color(uvs, mesh) / 255\n", " return cs" ] }, @@ -284,7 +284,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def uniformly_sample_from_mesh(key, n, mesh, with_color=True):\n", " \"\"\"Uniformly sample `n` points and optionally their color on the surface from a mesh.\"\"\"\n", " xs, ps, fs = sample_from_mesh(key, n, mesh)\n", @@ -292,7 +292,7 @@ " if with_color:\n", " cs = get_colors_from_mesh(ps, fs, mesh)\n", " else:\n", - " cs = jnp.full((n,3), 0.5)\n", + " cs = jnp.full((n, 3), 0.5)\n", "\n", " return xs, cs" ] @@ -303,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def get_cluster_counts(m, labels):\n", " nums = []\n", " for label in range(m):\n", @@ -311,7 +311,7 @@ " return np.array(nums)\n", "\n", "\n", - "#|export\n", + "# |export\n", "def get_cluster_colors(cs, m, labels):\n", " colors = []\n", " for label in range(m):\n", @@ -325,22 +325,21 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def get_mean_colors(cs, n, labels):\n", " mean_colors = []\n", - " nums = []\n", + " nums = []\n", " for label in range(n):\n", " idx = labels == label\n", " num = np.sum(idx)\n", - " if num == 0: \n", + " if num == 0:\n", " c = np.array([0.5, 0.5, 0.5, 0.0])\n", - " else: \n", + " else:\n", " c = np.mean(cs[idx], axis=0)\n", " nums.append(num)\n", " mean_colors.append(c)\n", "\n", - " return np.array(mean_colors), np.array(nums)\n", - " " + " return np.array(mean_colors), np.array(nums)" ] }, { @@ -349,12 +348,12 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix:\n", + "# |export\n", + "def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix:\n", " \"\"\"Returns A with cov = A@A.T\"\"\"\n", " sigma, U = jnp.linalg.eigh(cov)\n", " D = jnp.diag(jnp.sqrt(sigma))\n", - " return U @ D @ U.T\n" + " return U @ D @ U.T" ] }, { @@ -363,27 +362,33 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def pack_transform(x, A, scale=1.0):\n", - " B = scale*A\n", - " return jnp.array([\n", - " [B[0,0], B[0,1], B[0,2], x[0]], \n", - " [B[1,0], B[1,1], B[1,2], x[1]],\n", - " [B[2,0], B[2,1], B[2,2], x[2]],\n", - " [0.0, 0.0, 0.0, 1.0]\n", - " ]).T\n", - "\n", - "\n", - "def transform_from_gaussian(mu:Vector, cov:CovarianceMatrix=jnp.eye(3), scale=1.0) -> 
Matrix:\n", + " B = scale * A\n", + " return jnp.array(\n", + " [\n", + " [B[0, 0], B[0, 1], B[0, 2], x[0]],\n", + " [B[1, 0], B[1, 1], B[1, 2], x[1]],\n", + " [B[2, 0], B[2, 1], B[2, 2], x[2]],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " ]\n", + " ).T\n", + "\n", + "\n", + "def transform_from_gaussian(\n", + " mu: Vector, cov: CovarianceMatrix = jnp.eye(3), scale=1.0\n", + ") -> Matrix:\n", " \"\"\"Returns an affine linear transformation 4x4 matrix from a Gaussian.\"\"\"\n", " A = ellipsoid_embedding(cov)\n", " B = scale * A\n", - " return jnp.array([\n", - " [B[0,0], B[0,1], B[0,2], mu[0]], \n", - " [B[1,0], B[1,1], B[1,2], mu[1]],\n", - " [B[2,0], B[2,1], B[2,2], mu[2]],\n", - " [0.0, 0.0, 0.0, 1.0]\n", - " ]).T" + " return jnp.array(\n", + " [\n", + " [B[0, 0], B[0, 1], B[0, 2], mu[0]],\n", + " [B[1, 0], B[1, 1], B[1, 2], mu[1]],\n", + " [B[2, 0], B[2, 1], B[2, 2], mu[2]],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " ]\n", + " ).T" ] }, { @@ -392,7 +397,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def create_ellipsoid_trimesh(covariance_matrix, num_points=10, scale=0.02):\n", " # Create a sphere\n", " u = np.linspace(0, 2 * np.pi, num_points)\n", @@ -404,13 +409,15 @@ " # Transform the sphere to the ellipsoid\n", " sigma, U = np.linalg.eig(covariance_matrix)\n", " D = np.diag(np.sqrt(sigma))\n", - " ellipsoid = U @ D @ np.linalg.inv(U) @ np.vstack([x.flatten(), y.flatten(), z.flatten()])\n", + " ellipsoid = (\n", + " U @ D @ np.linalg.inv(U) @ np.vstack([x.flatten(), y.flatten(), z.flatten()])\n", + " )\n", "\n", " # Reshape the ellipsoid to match the shape of the original sphere vertices\n", " ellipsoid = ellipsoid.T.reshape(num_points, num_points, 3)\n", "\n", " # Create mesh data\n", - " mesh_vertices = scale*ellipsoid.reshape(-1, 3)\n", + " mesh_vertices = scale * ellipsoid.reshape(-1, 3)\n", " mesh_faces = []\n", " for i in range(num_points - 1):\n", " for j in range(num_points - 1):\n", @@ -448,9 +455,9 @@ "metadata": {}, "outputs": [], "source": [ - "t=0\n", + "t = 0\n", "fname = f\"data/flag_objs/flag_t_{t}.obj\"\n", - "mesh = trimesh.load(fname)\n", + "mesh = trimesh.load(fname)\n", "mesh = patch_trimesh(mesh)" ] }, @@ -461,17 +468,14 @@ "outputs": [], "source": [ "name = \"scissors\"\n", - "idx = {\n", - " \"banana\": 10,\n", - " \"scissors\": 17\n", - "}[name]\n", + "idx = {\"banana\": 10, \"scissors\": 17}[name]\n", "\n", "\n", - "_scaling = 1e-3\n", - "model_dir = os.path.join(b3d.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "_scaling = 1e-3\n", + "model_dir = os.path.join(b3d.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "mesh = trimesh.load(mesh_path)\n", - "mesh.vertices *= _scaling \n", + "mesh.vertices *= _scaling\n", "\n", "\n", "mesh = patch_trimesh(mesh)" @@ -560,17 +564,20 @@ "# ----------\n", "key = keysplit(key)\n", "n_components = 1_000\n", - "noise = 0.0; \n", - "X = xs + np.random.randn(*xs.shape)*noise\n", - "means_init = np.array(uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]);\n", - "\n", - "\n", + "noise = 0.0\n", + "X = xs + np.random.randn(*xs.shape) * noise\n", + "means_init = np.array(\n", + " uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]\n", + ")\n", "# Fit the GMM\n", "# -----------\n", - "gm = GaussianMixture(n_components=n_components, \n", - " tol=1e-3, max_iter=100, 
\n", - " covariance_type=\"full\", \n", - " means_init=means_init).fit(X)" + "gm = GaussianMixture(\n", + " n_components=n_components,\n", + " tol=1e-3,\n", + " max_iter=100,\n", + " covariance_type=\"full\",\n", + " means_init=means_init,\n", + ").fit(X)" ] }, { @@ -590,11 +597,11 @@ } ], "source": [ - "mus = gm.means_\n", - "covs = gm.covariances_\n", - "labels = gm.predict(X)\n", - "choleskys = vmap(ellipsoid_embedding)(np.array(covs))\n", - "transforms = vmap(pack_transform, (0,0,None))(mus, choleskys, 2.0)\n", + "mus = gm.means_\n", + "covs = gm.covariances_\n", + "labels = gm.predict(X)\n", + "choleskys = vmap(ellipsoid_embedding)(np.array(covs))\n", + "transforms = vmap(pack_transform, (0, 0, None))(mus, choleskys, 2.0)\n", "mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)\n", "valid = nums > 0\n", "sum(valid)" @@ -618,7 +625,13 @@ ], "source": [ "fname = f\"data/gaussian_examples/gaussians_{name}_{n_components}.npz\"\n", - "jnp.savez(fname, mus=mus[valid], covs=covs[valid], colors=mean_colors[valid], choleskys=choleskys[valid])\n", + "jnp.savez(\n", + " fname,\n", + " mus=mus[valid],\n", + " covs=covs[valid],\n", + " colors=mean_colors[valid],\n", + " choleskys=choleskys[valid],\n", + ")\n", "fname" ] }, @@ -637,8 +650,8 @@ "source": [ "import traceviz.client\n", "import numpy as np\n", - "from traceviz.proto import viz_pb2\n", - "import json\n" + "from traceviz.proto import viz_pb2\n", + "import json" ] }, { @@ -661,22 +674,31 @@ } ], "source": [ - "\n", "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\"type\": \"setup\",}))\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"setup\",\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\n", - " \"type\": \"gaussians\",\n", - " \"data\": {\n", - " \"transforms\": np.array(transforms[None, valid]),\n", - " 'colors': np.array(mean_colors[None, valid])\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"gaussians\",\n", + " \"data\": {\n", + " \"transforms\": np.array(transforms[None, valid]),\n", + " \"colors\": np.array(mean_colors[None, valid]),\n", + " },\n", " }\n", - " }))\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -692,7 +714,14 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.savez(f\"data/gaussians_scissors_{mus.shape[0]}.npz\", mus=mus, covs=covs, choleskys=choleskys, mean_colors=mean_colors, nums=nums)" + "jnp.savez(\n", + " f\"data/gaussians_scissors_{mus.shape[0]}.npz\",\n", + " mus=mus,\n", + " covs=covs,\n", + " choleskys=choleskys,\n", + " mean_colors=mean_colors,\n", + " nums=nums,\n", + ")" ] }, { @@ -705,7 +734,7 @@ "\n", "n = 10_000\n", "xs, cs = uniformly_sample_from_mesh(key, n, mesh, with_color=True)\n", - "covs = jnp.tile(jnp.eye(3), (n,1,1))\n", + "covs = jnp.tile(jnp.eye(3), (n, 1, 1))\n", "\n", "# jnp.savez(f\"data/gaussians_test.npz\", mus=xs, covs=covs, colors=cs)" ] @@ -731,16 +760,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", 
stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(xs), \n", - " 'colors': np.array(cs),\n", - " 'scales': np.ones(n)*0.01,\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(xs),\n", + " \"colors\": np.array(cs),\n", + " \"scales\": np.ones(n) * 0.01,\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -796,16 +829,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(Xs), \n", - " 'colors': np.array(Cs),\n", - " 'scales': np.ones((len(Xs),n))*0.004,\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(Xs),\n", + " \"colors\": np.array(Cs),\n", + " \"scales\": np.ones((len(Xs), n)) * 0.004,\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -825,22 +862,23 @@ } ], "source": [ - "transforms = vmap(transform_from_gaussian, (0,0,None))(xs, covs, .02)\n", + "transforms = vmap(transform_from_gaussian, (0, 0, None))(xs, covs, 0.02)\n", "\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"Gaussians2\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms ), \n", - " 'colors': np.array(cs)\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\"transforms\": np.array(transforms), \"colors\": np.array(cs)}\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -869,7 +907,7 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -893,9 +931,7 @@ } ], "source": [ - "\n", - "\n", - "cm = getattr(plt.cm, \"cool\")\n", + "cm = getattr(plt.cm, \"cool\")\n", "# cs = cm(plt.Normalize()(vs_))\n", "cm" ] @@ -919,15 +955,20 @@ ], "source": [ "key = keysplit(key)\n", - "xs_, cs_, _ = uniformly_sample_from_mesh(key, 15_000, mesh.faces, mesh.vertices, \n", - " texture_uv=texvis.uv, uv_to_color=texvis.material.to_color)\n", + "xs_, cs_, _ = uniformly_sample_from_mesh(\n", + " key,\n", + " 15_000,\n", + " mesh.faces,\n", + " mesh.vertices,\n", + " texture_uv=texvis.uv,\n", + " uv_to_color=texvis.material.to_color,\n", + ")\n", "\n", "# labels_ = 
gm.predict(xs_)\n", "# cols_ = mean_colors[labels_]\n", "# cs_.shape, cols_.shape\n", "\n", "\n", - "\n", "# distances = np.linalg.norm(cs_ - cols_, axis=1)\n", "# print(distances.shape)\n", "# print(np.min(distances), np.max(distances))\n", @@ -953,24 +994,24 @@ } ], "source": [ - "\n", - "\n", "print(xs_.shape, cs_.shape)\n", "\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"Spheres\"})\n", "msg.payload.data.MergeFrom(\n", - " traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(xs_), \n", - " 'colors': np.array(cs_), \n", - " \"scales\": 0.025*np.ones(len(xs_))\n", - " })\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(xs_),\n", + " \"colors\": np.array(cs_),\n", + " \"scales\": 0.025 * np.ones(len(xs_)),\n", + " }\n", + " )\n", ")\n", - " \n", + "\n", "\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { diff --git a/scripts/_mkl/notebooks/05b - Gaussian Approximation.ipynb b/scripts/_mkl/notebooks/05b - Gaussian Approximation.ipynb index dbd23965..945d77b9 100644 --- a/scripts/_mkl/notebooks/05b - Gaussian Approximation.ipynb +++ b/scripts/_mkl/notebooks/05b - Gaussian Approximation.ipynb @@ -99,14 +99,14 @@ "outputs": [], "source": [ "from bayes3d._mkl.trimesh_to_gaussians import (\n", - " patch_trimesh, \n", + " patch_trimesh,\n", " sample_from_mesh,\n", " barycentric_to_mesh as _barycentric_to_mesh,\n", - " uniformly_sample_from_mesh, \n", - " ellipsoid_embedding, \n", - " get_mean_colors, \n", + " uniformly_sample_from_mesh,\n", + " ellipsoid_embedding,\n", + " get_mean_colors,\n", " pack_transform,\n", - " transform_from_gaussian\n", + " transform_from_gaussian,\n", ")\n", "import trimesh\n", "import numpy as np\n", @@ -120,7 +120,7 @@ "# SEED\n", "key = jax.random.PRNGKey(0)\n", "\n", - "barycentric_to_mesh = vmap(_barycentric_to_mesh, (0,0,None))" + "barycentric_to_mesh = vmap(_barycentric_to_mesh, (0, 0, None))" ] }, { @@ -131,7 +131,7 @@ "source": [ "import traceviz.client\n", "import numpy as np\n", - "from traceviz.proto import viz_pb2\n", + "from traceviz.proto import viz_pb2\n", "import json\n", "import matplotlib.pyplot as plt" ] @@ -195,8 +195,9 @@ "metadata": {}, "outputs": [], "source": [ - "def fit(key, mesh, means_init, precisions_init, covariance_type=\"full\", iter=20, noise=0.0):\n", - " \n", + "def fit(\n", + " key, mesh, means_init, precisions_init, covariance_type=\"full\", iter=20, noise=0.0\n", + "):\n", " # SAMPLE FROM MESH\n", " # ----------------\n", " _, key = keysplit(key, 1, 1)\n", @@ -207,21 +208,24 @@ " # ----------\n", " key = keysplit(key)\n", " n_components = means_init.shape[0]\n", - " X = xs + np.random.randn(*xs.shape)*noise\n", + " X = xs + np.random.randn(*xs.shape) * noise\n", "\n", " # FIT THE GMM\n", " # -----------\n", - " gm = GaussianMixture(n_components=n_components, \n", - " tol=1e-3, max_iter=iter, \n", - " covariance_type=covariance_type, \n", - " means_init=means_init,\n", - " precisions_init=precisions_init).fit(X)\n", + " gm = GaussianMixture(\n", + " n_components=n_components,\n", + " tol=1e-3,\n", + " max_iter=iter,\n", + " covariance_type=covariance_type,\n", + " means_init=means_init,\n", + " precisions_init=precisions_init,\n", + " ).fit(X)\n", "\n", - " mus = gm.means_\n", + " mus = gm.means_\n", " if gm.covariance_type == \"spherical\":\n", - " covs = gm.covariances_[:,None,None]*jnp.eye(3)[None,:,:]\n", + " covs = gm.covariances_[:, None, None] * 
jnp.eye(3)[None, :, :]\n", " else:\n", - " covs = gm.covariances_\n", + " covs = gm.covariances_\n", " labels = gm.predict(X)\n", " mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)\n", "\n", @@ -256,25 +260,29 @@ "# ----------\n", "key = keysplit(key)\n", "n_components = 100\n", - "noise = 0.01; \n", - "X = xs + np.random.randn(*xs.shape)*noise\n", - "means_init = np.array(uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]);\n", - "\n", + "noise = 0.01\n", + "X = xs + np.random.randn(*xs.shape) * noise\n", + "means_init = np.array(\n", + " uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]\n", + ")\n", "# FIT THE GMM\n", "# -----------\n", - "gm = GaussianMixture(n_components=n_components, \n", - " tol=1e-3, max_iter=100, \n", - " covariance_type=\"spherical\", \n", - " means_init=means_init).fit(X)\n", + "gm = GaussianMixture(\n", + " n_components=n_components,\n", + " tol=1e-3,\n", + " max_iter=100,\n", + " covariance_type=\"spherical\",\n", + " means_init=means_init,\n", + ").fit(X)\n", "\n", "\n", - "mus = gm.means_\n", + "mus = gm.means_\n", "if gm.covariance_type == \"spherical\":\n", - " covs = gm.covariances_[:,None,None]*jnp.eye(3)[None,:,:]\n", + " covs = gm.covariances_[:, None, None] * jnp.eye(3)[None, :, :]\n", "else:\n", - " covs = gm.covariances_\n", - "labels = gm.predict(X)\n", - "transforms = vmap(transform_from_gaussian, (0,0,None))(mus, covs, 2.0)\n", + " covs = gm.covariances_\n", + "labels = gm.predict(X)\n", + "transforms = vmap(transform_from_gaussian, (0, 0, None))(mus, covs, 2.0)\n", "mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)\n", "\n", "print(f\"\"\"\n", @@ -305,16 +313,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"Gaussians2\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms )[nums>0], \n", - " 'colors': np.array(mean_colors)[nums>0] \n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"transforms\": np.array(transforms)[nums > 0],\n", + " \"colors\": np.array(mean_colors)[nums > 0],\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -323,8 +335,9 @@ "metadata": {}, "outputs": [], "source": [ - "def fit(key, mesh, means_init, precisions_init, covariance_type=\"full\", iter=20, noise=0.0):\n", - " \n", + "def fit(\n", + " key, mesh, means_init, precisions_init, covariance_type=\"full\", iter=20, noise=0.0\n", + "):\n", " # SAMPLE FROM MESH\n", " # ----------------\n", " _, key = keysplit(key, 1, 1)\n", @@ -335,21 +348,24 @@ " # ----------\n", " key = keysplit(key)\n", " n_components = means_init.shape[0]\n", - " X = xs + np.random.randn(*xs.shape)*noise\n", + " X = xs + np.random.randn(*xs.shape) * noise\n", "\n", " # FIT THE GMM\n", " # -----------\n", - " gm = GaussianMixture(n_components=n_components, \n", - " tol=1e-3, max_iter=iter, \n", - " covariance_type=covariance_type, \n", - " means_init=means_init,\n", - " precisions_init=precisions_init).fit(X)\n", + " gm = GaussianMixture(\n", + " n_components=n_components,\n", + " tol=1e-3,\n", + " 
max_iter=iter,\n", + " covariance_type=covariance_type,\n", + " means_init=means_init,\n", + " precisions_init=precisions_init,\n", + " ).fit(X)\n", "\n", - " mus = gm.means_\n", + " mus = gm.means_\n", " if gm.covariance_type == \"spherical\":\n", - " covs = gm.covariances_[:,None,None]*jnp.eye(3)[None,:,:]\n", + " covs = gm.covariances_[:, None, None] * jnp.eye(3)[None, :, :]\n", " else:\n", - " covs = gm.covariances_\n", + " covs = gm.covariances_\n", " labels = gm.predict(X)\n", " mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)\n", "\n", @@ -441,12 +457,20 @@ "CVs = [covs]\n", "CLs = [mean_colors]\n", "\n", - "for t in range(1,5):\n", + "for t in range(1, 5):\n", " print(t, end=\"\\r\")\n", " mesh = load_mesh(t)\n", " key = keysplit(key)\n", "\n", - " mus, covs, mean_colors = fit(key, mesh, MUs[-1], vmap(jnp.linalg.inv)(CVs[-1])[:,0,0], covariance_type=\"spherical\", iter=10, noise=0.0)\n", + " mus, covs, mean_colors = fit(\n", + " key,\n", + " mesh,\n", + " MUs[-1],\n", + " vmap(jnp.linalg.inv)(CVs[-1])[:, 0, 0],\n", + " covariance_type=\"spherical\",\n", + " iter=10,\n", + " noise=0.0,\n", + " )\n", " MUs.append(mus)\n", " CVs.append(covs)\n", " CLs.append(mean_colors)" @@ -473,26 +497,27 @@ "mus = MUs[t]\n", "covs = CVs[t]\n", "mean_colors = CLs[t]\n", - "transforms = vmap(transform_from_gaussian, (0,0,None))(mus, covs, 3.0)\n", + "transforms = vmap(transform_from_gaussian, (0, 0, None))(mus, covs, 3.0)\n", "\n", - "colors = 0.4*jnp.ones_like(mean_colors)\n", - "colors = colors.at[:,3].set(1.)\n", + "colors = 0.4 * jnp.ones_like(mean_colors)\n", + "colors = colors.at[:, 3].set(1.0)\n", "\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"Gaussians2\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms ), \n", - " 'colors': np.array(colors)\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\"transforms\": np.array(transforms), \"colors\": np.array(colors)}\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { diff --git a/scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb b/scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb index b6d008ae..85b1a610 100644 --- a/scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb +++ b/scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb @@ -16,7 +16,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp gaussian_renderer" + "# |default_exp gaussian_renderer" ] }, { @@ -40,7 +40,7 @@ } ], "source": [ - "#|export\n", + "# |export\n", "import bayes3d as b3d\n", "import trimesh\n", "import os\n", @@ -52,12 +52,16 @@ "import jax.numpy as jnp\n", "from functools import partial\n", "from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics\n", - "from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth\n", + "from bayes3d.transforms_3d import (\n", + " transform_from_pos_target_up,\n", + " add_homogenous_ones,\n", + " unproject_depth,\n", + ")\n", "import tensorflow_probability as tfp\n", "from 
tensorflow_probability.substrates.jax.math import lambertw\n", "\n", - "normal_cdf = jax.scipy.stats.norm.cdf\n", - "normal_pdf = jax.scipy.stats.norm.pdf\n", + "normal_cdf = jax.scipy.stats.norm.cdf\n", + "normal_pdf = jax.scipy.stats.norm.pdf\n", "normal_logpdf = jax.scipy.stats.norm.logpdf\n", "inv = jnp.linalg.inv\n", "\n", @@ -71,7 +75,7 @@ "outputs": [], "source": [ "import traceviz.client\n", - "from traceviz.proto import viz_pb2\n", + "from traceviz.proto import viz_pb2\n", "import json\n", "import inspect\n", "from IPython.display import Markdown as md\n", @@ -84,7 +88,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from bayes3d._mkl.types import *" ] }, @@ -94,8 +98,8 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix:\n", + "# |export\n", + "def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix:\n", " \"\"\"Returns A with cov = A@A.T\"\"\"\n", " sigma, U = jnp.linalg.eigh(cov)\n", " D = jnp.diag(jnp.sqrt(sigma))\n", @@ -108,8 +112,8 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def bilinear(x:Array, y:Array, A:Matrix) -> Float:\n", + "# |export\n", + "def bilinear(x: Array, y: Array, A: Matrix) -> Float:\n", " return x.T @ A @ y" ] }, @@ -119,21 +123,21 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def log_gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float:\n", + "# |export\n", + "def log_gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float:\n", " \"\"\"Evaluate an **unnormalized** gaussian at a given point.\"\"\"\n", - " return -0.5 * bilinear(x-mu, x-mu, P)\n", + " return -0.5 * bilinear(x - mu, x - mu, P)\n", "\n", "\n", - "def gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float:\n", + "def gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float:\n", " \"\"\"Evaluate an **unnormalized** gaussian at a given point.\"\"\"\n", - " return jnp.exp(-0.5 * bilinear(x-mu, x-mu, P))\n", + " return jnp.exp(-0.5 * bilinear(x - mu, x - mu, P))\n", "\n", "\n", - "def gaussian_normalizing_constant(P:PrecisionMatrix) -> Float:\n", + "def gaussian_normalizing_constant(P: PrecisionMatrix) -> Float:\n", " \"\"\"Returns the normalizing constant of an unnormalized gaussian.\"\"\"\n", " n = P.shape[0]\n", - " return jnp.sqrt(jnp.linalg.det(P)/(2*jnp.pi)**n)" + " return jnp.sqrt(jnp.linalg.det(P) / (2 * jnp.pi) ** n)" ] }, { @@ -142,17 +146,19 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def gaussian_restriction_to_ray(loc:Vector, P:PrecisionMatrix, A:CholeskyMatrix, x:Vector, v:Direction):\n", + "# |export\n", + "def gaussian_restriction_to_ray(\n", + " loc: Vector, P: PrecisionMatrix, A: CholeskyMatrix, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", - " Restricts a gaussian to a ray and returns \n", - " the mean `mu` and standard deviation `std`, s.t. we have \n", + " Restricts a gaussian to a ray and returns\n", + " the mean `mu` and standard deviation `std`, s.t. 
we have\n", " $$\n", " P(x + t*v | loc, cov) = P(x + mu*v | loc, cov) * N(t | mu, std)\n", " $$\n", " \"\"\"\n", - " mu = bilinear(loc - x, v, P)/bilinear(v, v, P)\n", - " std = 1/jnp.linalg.norm(inv(A)@v)\n", + " mu = bilinear(loc - x, v, P) / bilinear(v, v, P)\n", + " std = 1 / jnp.linalg.norm(inv(A) @ v)\n", " return mu, std" ] }, @@ -169,17 +175,17 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def discrete_arrival_probabilities(occupancy_probs:Vector):\n", + "# |export\n", + "def discrete_arrival_probabilities(occupancy_probs: Vector):\n", " \"\"\"\n", - " Given an vector of `n` occupancy probabilities of neighbouring pixels, \n", - " it returns a vector of length `n+1` containing the probabilities of stopping \n", + " Given an vector of `n` occupancy probabilities of neighbouring pixels,\n", + " it returns a vector of length `n+1` containing the probabilities of stopping\n", " at a each pixel (while traversing them left to right) or not stopping at all.\n", "\n", " The return array is given by:\n", " $$\n", " q_i = p_i \\cdot \\prod_{j=0}^{i-1} (1 - p_j)\n", - " \n", + "\n", " $$\n", " for $i=0,...,n-1$, and\n", " $$\n", @@ -192,7 +198,9 @@ " X(T) = \\sigma(T)*\\exp(- \\int_0^T \\sigma(t) \\ dt).\n", " $$\n", " \"\"\"\n", - " transmittances = jnp.concatenate([jnp.array([1.0]), jnp.cumprod(1-occupancy_probs)])\n", + " transmittances = jnp.concatenate(\n", + " [jnp.array([1.0]), jnp.cumprod(1 - occupancy_probs)]\n", + " )\n", " extended_occupancies = jnp.concatenate([occupancy_probs, jnp.array([1.0])])\n", " return extended_occupancies * transmittances" ] @@ -215,15 +223,15 @@ ], "source": [ "key = keysplit(key)\n", - "occupancy_probs = 0.3*jax.random.uniform(key, (10,))\n", - "arrival_probs = discrete_arrival_probabilities(occupancy_probs)\n", + "occupancy_probs = 0.3 * jax.random.uniform(key, (10,))\n", + "arrival_probs = discrete_arrival_probabilities(occupancy_probs)\n", "\n", "assert jnp.isclose(arrival_probs.sum(), 1.0)\n", "\n", "# =======================\n", - "plt.figure(figsize=(3,1))\n", + "plt.figure(figsize=(3, 1))\n", "plt.plot(jnp.arange(len(occupancy_probs)), occupancy_probs)\n", - "plt.plot(jnp.arange(len(occupancy_probs)+1), arrival_probs);" + "plt.plot(jnp.arange(len(occupancy_probs) + 1), arrival_probs);" ] }, { @@ -232,7 +240,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def gaussian_time_of_arrival(xs, mu, sig, w=1.0):\n", " \"\"\"\n", " Time of first arrival for a **single** weighted 1-dimensional Gaussian, i.e. returns an array of\n", @@ -241,23 +249,28 @@ " Y(T) = w*g(T | \\mu, \\sigma)*\\exp(- \\int_0^T w*g(t | \\mu, \\sigma) \\ dt).\n", " $$\n", " \"\"\"\n", - " ys = w*normal_pdf(xs, loc=mu, scale=sig) * jnp.exp(\n", - " - w*normal_cdf(xs, loc=mu, scale=sig) \n", - " + w*normal_cdf(0.0, loc=mu, scale=sig))\n", - " return ys \n", - "\n", - "\n", - "def gaussian_most_likely_time_of_arrival(mu, sig, w=1.):\n", + " ys = (\n", + " w\n", + " * normal_pdf(xs, loc=mu, scale=sig)\n", + " * jnp.exp(\n", + " -w * normal_cdf(xs, loc=mu, scale=sig)\n", + " + w * normal_cdf(0.0, loc=mu, scale=sig)\n", + " )\n", + " )\n", + " return ys\n", + "\n", + "\n", + "def gaussian_most_likely_time_of_arrival(mu, sig, w=1.0):\n", " \"\"\"\n", " Returns the most likely time of first arrival\n", - " for a single weighted 1-dimensional Gaussian, i.e. the argmax of \n", + " for a single weighted 1-dimensional Gaussian, i.e. 
the argmax of\n", " $$\n", " Y(T) = w*g(T | \\mu, \\sigma)*\\exp(- \\int_0^T w*g(t | \\mu, \\sigma) \\ dt).\n", " $$\n", " \"\"\"\n", " # TODO: Check if this is correct, cf. my notes.\n", - " Z = jnp.sqrt(lambertw(1/(2*jnp.pi) * w**2))\n", - " return mu - Z*sig" + " Z = jnp.sqrt(lambertw(1 / (2 * jnp.pi) * w**2))\n", + " return mu - Z * sig" ] }, { @@ -289,13 +302,13 @@ } ], "source": [ - "key, keys = keysplit(key,1,3)\n", + "key, keys = keysplit(key, 1, 3)\n", "\n", - "xs = jnp.linspace(0,11,20_000)\n", - "mu = jax.random.uniform(keys[0],(), float, 3.0,8.0)\n", - "sig = jax.random.uniform(keys[1],(), float, 0.1,2.0)\n", - "w = jax.random.uniform(keys[2],(), float, 0.1,1.0)\n", - "t = gaussian_most_likely_time_of_arrival(mu, sig, w=w)\n", + "xs = jnp.linspace(0, 11, 20_000)\n", + "mu = jax.random.uniform(keys[0], (), float, 3.0, 8.0)\n", + "sig = jax.random.uniform(keys[1], (), float, 0.1, 2.0)\n", + "w = jax.random.uniform(keys[2], (), float, 0.1, 1.0)\n", + "t = gaussian_most_likely_time_of_arrival(mu, sig, w=w)\n", "\n", "print(f\"\"\"\n", " mu: {mu:0.3f}\n", @@ -305,16 +318,28 @@ "\"\"\")\n", "\n", "# =======================\n", - "plt.figure(figsize=(10,2))\n", - "plt.gca().spines['right'].set_visible(False)\n", - "plt.gca().spines['top'].set_visible(False)\n", - "plt.plot(xs, gaussian_time_of_arrival(xs, mu, sig, w), c=\"C1\", alpha=1., label=\"Time of first Arrival\")\n", + "plt.figure(figsize=(10, 2))\n", + "plt.gca().spines[\"right\"].set_visible(False)\n", + "plt.gca().spines[\"top\"].set_visible(False)\n", + "plt.plot(\n", + " xs,\n", + " gaussian_time_of_arrival(xs, mu, sig, w),\n", + " c=\"C1\",\n", + " alpha=1.0,\n", + " label=\"Time of first Arrival\",\n", + ")\n", "plt.fill_between(xs, gaussian_time_of_arrival(xs, mu, sig, w), color=\"C1\", alpha=0.1)\n", - "plt.vlines(t, 0, gaussian_time_of_arrival(t, mu, sig, w), color=\"C1\", alpha=.5)\n", + "plt.vlines(t, 0, gaussian_time_of_arrival(t, mu, sig, w), color=\"C1\", alpha=0.5)\n", "plt.scatter(t, gaussian_time_of_arrival(t, mu, sig, w), c=\"C1\")\n", - "plt.plot(xs, w*normal_pdf(xs, loc=mu, scale=sig), c=\"C0\", alpha=1., label=\"Weighted Gaussian\")\n", - "plt.vlines(mu, 0, w*normal_pdf(mu, loc=mu, scale=sig), color=\"C0\", alpha=.5)\n", - "plt.scatter(mu, w*normal_pdf(mu, loc=mu, scale=sig), c=\"C0\")\n", + "plt.plot(\n", + " xs,\n", + " w * normal_pdf(xs, loc=mu, scale=sig),\n", + " c=\"C0\",\n", + " alpha=1.0,\n", + " label=\"Weighted Gaussian\",\n", + ")\n", + "plt.vlines(mu, 0, w * normal_pdf(mu, loc=mu, scale=sig), color=\"C0\", alpha=0.5)\n", + "plt.scatter(mu, w * normal_pdf(mu, loc=mu, scale=sig), c=\"C0\")\n", "plt.legend();" ] }, @@ -331,15 +356,17 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def weighted_arrival_intersection(mu:Vector, P:PrecisionMatrix, A:CholeskyMatrix, w:Float, x:Vector, v:Direction):\n", + "# |export\n", + "def weighted_arrival_intersection(\n", + " mu: Vector, P: PrecisionMatrix, A: CholeskyMatrix, w: Float, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", " Returns the \"intersection\" of a ray with a gaussian which we define as\n", " the mode of the gaussian restricted to the ray.\n", " \"\"\"\n", " t0, sig0 = gaussian_restriction_to_ray(mu, P, A, x, v)\n", - " w0 = w*gaussian(t0*v, mu, P)\n", - " Z = w0/gaussian_normalizing_constant(P)\n", + " w0 = w * gaussian(t0 * v, mu, P)\n", + " Z = w0 / gaussian_normalizing_constant(P)\n", " t = gaussian_most_likely_time_of_arrival(t0, sig0, Z)\n", " return t, w0" ] @@ -350,24 +377,26 @@ "metadata": {}, "outputs": [], 
"source": [ - "#|export\n", - "def argmax_intersection(mu:Vector, P:PrecisionMatrix, x:Vector, v:Direction):\n", + "# |export\n", + "def argmax_intersection(mu: Vector, P: PrecisionMatrix, x: Vector, v: Direction):\n", " \"\"\"\n", " Returns the \"intersection\" of a ray with a gaussian which we define as\n", " the mode of the gaussian restricted to the ray.\n", " \"\"\"\n", - " t = bilinear(mu - x, v, P)/bilinear(v, v, P)\n", + " t = bilinear(mu - x, v, P) / bilinear(v, v, P)\n", " return t\n", "\n", "\n", - "#|export\n", - "def weighted_argmax_intersection(mu:Vector, P:PrecisionMatrix, w:Float, x:Vector, v:Direction):\n", + "# |export\n", + "def weighted_argmax_intersection(\n", + " mu: Vector, P: PrecisionMatrix, w: Float, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", " Returns the \"intersection\" of a ray with a gaussian which we define as\n", " the mode of the gaussian restricted to the ray.\n", " \"\"\"\n", - " t = bilinear(mu - x, v, P)/bilinear(v, v, P)\n", - " return t, w*gaussian(x + t*v, mu, P)" + " t = bilinear(mu - x, v, P) / bilinear(v, v, P)\n", + " return t, w * gaussian(x + t * v, mu, P)" ] }, { @@ -388,40 +417,47 @@ ], "source": [ "# Define Gaussian\n", - "A = jnp.array([[5,1],[0,1]]).T \n", - "loc = jnp.array([0.,10.])\n", - "w = 1.\n", - "Cov = A@A.T \n", - "P = inv(Cov)\n", + "A = jnp.array([[5, 1], [0, 1]]).T\n", + "loc = jnp.array([0.0, 10.0])\n", + "w = 1.0\n", + "Cov = A @ A.T\n", + "P = inv(Cov)\n", "\n", "# Gaussian samples\n", "key = keysplit(key)\n", - "xs = jax.random.normal(key, shape=(2_000,2)) \n", - "ys = xs@A.T + loc\n", + "xs = jax.random.normal(key, shape=(2_000, 2))\n", + "ys = xs @ A.T + loc\n", "\n", "# Ray directions\n", "n = 100\n", - "ths = jnp.linspace(0, jnp.pi, n); # \"Thetas\"\n", - "vs = jnp.stack([jnp.cos(ths), jnp.sin(ths)], axis=1)\n", + "ths = jnp.linspace(0, jnp.pi, n) # \"Thetas\"\n", + "vs = jnp.stack([jnp.cos(ths), jnp.sin(ths)], axis=1)\n", "\n", "# Intersection points\n", "# Try both verions:\n", "# ts, ps = vmap(weighted_arrival_intersection, (None,None,None,None,None,0))(loc, P, A, w, jnp.zeros(2), vs)\n", - "ts, ps = vmap(weighted_argmax_intersection, (None,None,None,None,0))(loc, P, w, jnp.zeros(2), vs)\n", - "zs = ts[:,None]*vs \n", + "ts, ps = vmap(weighted_argmax_intersection, (None, None, None, None, 0))(\n", + " loc, P, w, jnp.zeros(2), vs\n", + ")\n", + "zs = ts[:, None] * vs\n", "valid = ts > 0\n", "\n", "\n", - "\n", - "# ============================ \n", - "fig, ax = plt.subplots(1,1, figsize=(10,4))\n", + "# ============================\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 4))\n", "axs = [ax]\n", - "axs[0].axis('off')\n", - "axs[0].set_aspect('equal')\n", - "axs[0].scatter(*ys[:].T, s=1., alpha=1., c=\"k\", marker=\".\")\n", - "axs[0].scatter(*zs[valid].T, s=20, c=ps[valid], cmap=\"plasma\", marker=\"s\",alpha=1)\n", - "plot_segs(jnp.concatenate([jnp.zeros((n,2)), zs], axis=-1)[valid], ax=axs[0], alpha=0.1, c=\"C0\", zorder=0)\n", - "axs[0].scatter(0,0,c=\"C0\", marker=\"^\",s=100, edgecolor=\"w\", zorder=1);" + "axs[0].axis(\"off\")\n", + "axs[0].set_aspect(\"equal\")\n", + "axs[0].scatter(*ys[:].T, s=1.0, alpha=1.0, c=\"k\", marker=\".\")\n", + "axs[0].scatter(*zs[valid].T, s=20, c=ps[valid], cmap=\"plasma\", marker=\"s\", alpha=1)\n", + "plot_segs(\n", + " jnp.concatenate([jnp.zeros((n, 2)), zs], axis=-1)[valid],\n", + " ax=axs[0],\n", + " alpha=0.1,\n", + " c=\"C0\",\n", + " zorder=0,\n", + ")\n", + "axs[0].scatter(0, 0, c=\"C0\", marker=\"^\", s=100, edgecolor=\"w\", zorder=1);" ] }, { @@ -476,12 
+512,12 @@ } ], "source": [ - "A1 = jnp.array([[5,1],[0,1]]).T \n", - "loc1 = jnp.array([0.,12.])\n", - "w1= 1.\n", - "A2 = jnp.array([[4,-1],[0,1]]).T \n", - "loc2 = jnp.array([5.,20.])\n", - "w2= 1.\n", + "A1 = jnp.array([[5, 1], [0, 1]]).T\n", + "loc1 = jnp.array([0.0, 12.0])\n", + "w1 = 1.0\n", + "A2 = jnp.array([[4, -1], [0, 1]]).T\n", + "loc2 = jnp.array([5.0, 20.0])\n", + "w2 = 1.0\n", "\n", "ps = []\n", "ts = []\n", @@ -490,18 +526,20 @@ "n = 100\n", "ths = jnp.linspace(0, jnp.pi, n)\n", "vs = jnp.stack([jnp.cos(ths), jnp.sin(ths)], axis=1)\n", - "for A, loc,w in [(A1,loc1,w1), (A2,loc2,w2)]: \n", + "for A, loc, w in [(A1, loc1, w1), (A2, loc2, w2)]:\n", " key = keysplit(key)\n", - " Cov = A@A.T \n", - " P = inv(Cov)\n", + " Cov = A @ A.T\n", + " P = inv(Cov)\n", "\n", - " xs = jax.random.normal(key, shape=(5_000,2)) \n", - " ys_ = xs@A.T + loc\n", - " cs_ = vmap(lambda y: w*gaussian(y, loc, P))(ys_)\n", + " xs = jax.random.normal(key, shape=(5_000, 2))\n", + " ys_ = xs @ A.T + loc\n", + " cs_ = vmap(lambda y: w * gaussian(y, loc, P))(ys_)\n", " order = jnp.argsort(cs_)\n", " ys.append(ys_[order])\n", " cs.append(cs_[order])\n", - " ts_, ps_ = vmap(weighted_arrival_intersection, (None,None,None,None,None,0))(loc, P, A, w, jnp.zeros(2), vs)\n", + " ts_, ps_ = vmap(weighted_arrival_intersection, (None, None, None, None, None, 0))(\n", + " loc, P, A, w, jnp.zeros(2), vs\n", + " )\n", " # ts_, ps_ = vmap(weighted_argmax_intersection, (None,None,None,None,0))(loc, P, w, jnp.zeros(2), vs)\n", " ts.append(ts_)\n", " ps.append(ps_)\n", @@ -510,47 +548,72 @@ "ts = jnp.array(ts)\n", "ys = jnp.array(ys)\n", "cs = jnp.array(cs)\n", - "zs = ts[:,:,None]*vs[None]\n", + "zs = ts[:, :, None] * vs[None]\n", "qs = vmap(discrete_arrival_probabilities)(ps.T)\n", "valid = ts > 0\n", "print(ps.shape, qs.shape)\n", - "print(ps[:,0])\n", + "print(ps[:, 0])\n", "print(qs[0])\n", "print(valid.shape)\n", "\n", - "zmax=30\n", - "ds = qs.T[0]*ts[0] + qs.T[1]*ts[1] + qs.T[2]*zmax\n", - "ds = ds[:,None]*vs\n", + "zmax = 30\n", + "ds = qs.T[0] * ts[0] + qs.T[1] * ts[1] + qs.T[2] * zmax\n", + "ds = ds[:, None] * vs\n", "\n", - "all_valid =valid[1]*valid[0]\n", + "all_valid = valid[1] * valid[0]\n", "\n", "ys_all = jnp.concatenate(ys)\n", "cs_all = jnp.concatenate(cs)\n", "order_all = jnp.argsort(cs_all)\n", - "ys_all = ys_all[order_all] \n", + "ys_all = ys_all[order_all]\n", "cs_all = cs_all[order_all]\n", "\n", "# =========================================\n", "s_inter = 10\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(10,4))\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 4))\n", "axs = [ax]\n", - "axs[0].set_xlim(-30,30)\n", + "axs[0].set_xlim(-30, 30)\n", "axs[0].set_ylim(0, 40)\n", - "axs[0].axis('off')\n", - "axs[0].set_aspect('equal')\n", - "axs[0].scatter(0,0,c=\"C0\", marker=\"^\",s=100, edgecolor=\"w\", zorder=1)\n", + "axs[0].axis(\"off\")\n", + "axs[0].set_aspect(\"equal\")\n", + "axs[0].scatter(0, 0, c=\"C0\", marker=\"^\", s=100, edgecolor=\"w\", zorder=1)\n", "# axs[0].scatter(*ys[0].T, s=10, c=np.clip(1-w1,0.3,0.7)*np.array([[1,1,1]]), edgecolor=np.clip(1-10*w1,0,1)*np.array([[1,1,1]]), linewidth=0.1, marker=\"o\")\n", "# axs[0].scatter(*ys[1].T, s=10, c=np.clip(1-w2,0.3,0.7)*np.array([[1,1,1]]), edgecolor=np.clip(1-10*w2,0,1)*np.array([[1,1,1]]), linewidth=0.1, marker=\"o\")\n", "# axs[0].scatter(*ys[0].T, s=30, c=cs[0], vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n", "# axs[0].scatter(*ys[1].T, s=30, c=cs[1], vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n", "axs[0].scatter(*ys_all.T, 
s=30, c=cs_all, vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n",
-    "axs[0].scatter(*zs[0,all_valid].T, s=s_inter, c=qs[all_valid,0], cmap=\"viridis\", vmin=qs.min(), vmax=qs.max(), marker=\"o\",alpha=1)\n",
-    "axs[0].scatter(*zs[1,all_valid].T, s=s_inter, c=qs[all_valid,1], cmap=\"viridis\", vmin=qs.min(), vmax=qs.max(), marker=\"o\",alpha=1);\n",
-    "\n",
+    "axs[0].scatter(\n",
+    "    *zs[0, all_valid].T,\n",
+    "    s=s_inter,\n",
+    "    c=qs[all_valid, 0],\n",
+    "    cmap=\"viridis\",\n",
+    "    vmin=qs.min(),\n",
+    "    vmax=qs.max(),\n",
+    "    marker=\"o\",\n",
+    "    alpha=1,\n",
+    ")\n",
+    "axs[0].scatter(\n",
+    "    *zs[1, all_valid].T,\n",
+    "    s=s_inter,\n",
+    "    c=qs[all_valid, 1],\n",
+    "    cmap=\"viridis\",\n",
+    "    vmin=qs.min(),\n",
+    "    vmax=qs.max(),\n",
+    "    marker=\"o\",\n",
+    "    alpha=1,\n",
+    ")\n",
     "# axs[0].scatter(*ds[all_valid].T, s=20, c=qs[valid[1]*valid[0],2], vmin=qs.min(), vmax=qs.max(), marker=\"s\",alpha=1);\n",
-    "axs[0].scatter(*(zmax * vs)[all_valid].T, s=s_inter, c=qs[valid[1]*valid[0],2], vmin=qs.min(), vmax=qs.max(), marker=\"o\", alpha=1);\n",
-    "axs[0].plot(*ds[all_valid].T, marker=\".\", c=\"C0\");\n"
+    "axs[0].scatter(\n",
+    "    *(zmax * vs)[all_valid].T,\n",
+    "    s=s_inter,\n",
+    "    c=qs[valid[1] * valid[0], 2],\n",
+    "    vmin=qs.min(),\n",
+    "    vmax=qs.max(),\n",
+    "    marker=\"o\",\n",
+    "    alpha=1,\n",
+    ")\n",
+    "axs[0].plot(*ds[all_valid].T, marker=\".\", c=\"C0\");"
    ]
   },
   {
@@ -561,21 +624,34 @@
    "source": [
     "%matplotlib inline\n",
     "from timeit import timeit\n",
-    "from ipywidgets import (interact, interactive, IntSlider, FloatSlider, HTMLMath, HTML,\n",
-    "                        FloatRangeSlider, RadioButtons, Checkbox, Dropdown, Button, VBox, HBox, Output)\n",
+    "from ipywidgets import (\n",
+    "    interact,\n",
+    "    interactive,\n",
+    "    IntSlider,\n",
+    "    FloatSlider,\n",
+    "    HTMLMath,\n",
+    "    HTML,\n",
+    "    FloatRangeSlider,\n",
+    "    RadioButtons,\n",
+    "    Checkbox,\n",
+    "    Dropdown,\n",
+    "    Button,\n",
+    "    VBox,\n",
+    "    HBox,\n",
+    "    Output,\n",
+    ")\n",
     "import warnings\n",
-    "warnings.filterwarnings('ignore')\n",
     "\n",
+    "warnings.filterwarnings(\"ignore\")\n",
     "\n",
     "\n",
     "def func(w1, w2):\n",
-    "    global key;\n",
-    "    A1 = jnp.array([[5,1],[0,1]]).T \n",
-    "    loc1 = jnp.array([0.,12.])\n",
-    "\n",
-    "    A2 = jnp.array([[4,-1],[0,1]]).T \n",
-    "    loc2 = jnp.array([5.,20.])\n",
+    "    global key\n",
+    "    A1 = jnp.array([[5, 1], [0, 1]]).T\n",
+    "    loc1 = jnp.array([0.0, 12.0])\n",
+    "\n",
+    "    A2 = jnp.array([[4, -1], [0, 1]]).T\n",
+    "    loc2 = jnp.array([5.0, 20.0])\n",
     "\n",
     "    ps = []\n",
     "    ts = []\n",
@@ -584,18 +660,20 @@
     "    n = 100\n",
     "    ths = jnp.linspace(0, jnp.pi, n)\n",
     "    vs = jnp.stack([jnp.cos(ths), jnp.sin(ths)], axis=1)\n",
-    "    for A, loc,w in [(A1,loc1,w1), (A2,loc2,w2)]: \n",
+    "    for A, loc, w in [(A1, loc1, w1), (A2, loc2, w2)]:\n",
     "        key = keysplit(key)\n",
-    "        Cov = A@A.T \n",
-    "        P = inv(Cov)\n",
+    "        Cov = A @ A.T\n",
+    "        P = inv(Cov)\n",
     "\n",
-    "        xs = jax.random.normal(key, shape=(5_000,2)) \n",
-    "        ys_ = xs@A.T + loc\n",
-    "        cs_ = vmap(lambda y: w*gaussian(y, loc, P))(ys_)\n",
+    "        xs = jax.random.normal(key, shape=(5_000, 2))\n",
+    "        ys_ = xs @ A.T + loc\n",
+    "        cs_ = vmap(lambda y: w * gaussian(y, loc, P))(ys_)\n",
     "        order = jnp.argsort(cs_)\n",
    "        ys.append(ys_[order])\n",
     "        cs.append(cs_[order])\n",
-    "        ts_, ps_ = vmap(arrival_intersection, (None,None,None,None,None,0))(loc, P, A, w, jnp.zeros(2), vs)\n",
+    "        ts_, ps_ = vmap(weighted_arrival_intersection, (None, None, None, None, None, 0))(\n",
+    "            loc, P, A, w, jnp.zeros(2), vs\n",
+    "        )\n",
     "        ts.append(ts_)\n",
     "        ps.append(ps_)\n",
     "\n",
@@ -603,50 +681,88 
@@ " ts = jnp.array(ts)\n", " ys = jnp.array(ys)\n", " cs = jnp.array(cs)\n", - " zs = ts[:,:,None]*vs[None]\n", + " zs = ts[:, :, None] * vs[None]\n", " qs = vmap(discrete_arrival_probabilities)(ps.T)\n", " valid = ts > 0\n", "\n", + " zmax = 30\n", + " ds = qs.T[0] * ts[0] + qs.T[1] * ts[1] + qs.T[2] * zmax\n", + " ds = ds[:, None] * vs\n", "\n", - " zmax=30\n", - " ds = qs.T[0]*ts[0] + qs.T[1]*ts[1] + qs.T[2]*zmax\n", - " ds = ds[:,None]*vs\n", - "\n", - " all_valid =valid[1]*valid[0]\n", + " all_valid = valid[1] * valid[0]\n", "\n", " ys_all = jnp.concatenate(ys)\n", " cs_all = jnp.concatenate(cs)\n", " order_all = jnp.argsort(cs_all)\n", - " ys_all = ys_all[order_all] \n", + " ys_all = ys_all[order_all]\n", " cs_all = cs_all[order_all]\n", "\n", " # =========================================\n", " s_inter = 10\n", "\n", - " fig, ax = plt.subplots(1,1, figsize=(10,4))\n", + " fig, ax = plt.subplots(1, 1, figsize=(10, 4))\n", " axs = [ax]\n", - " axs[0].set_xlim(-30,30)\n", + " axs[0].set_xlim(-30, 30)\n", " axs[0].set_ylim(0, 40)\n", - " axs[0].axis('off')\n", - " axs[0].set_aspect('equal')\n", - " axs[0].scatter(0,0,c=\"C0\", marker=\"^\",s=100, edgecolor=\"w\", zorder=1)\n", + " axs[0].axis(\"off\")\n", + " axs[0].set_aspect(\"equal\")\n", + " axs[0].scatter(0, 0, c=\"C0\", marker=\"^\", s=100, edgecolor=\"w\", zorder=1)\n", " # axs[0].scatter(*ys[0].T, s=10, c=np.clip(1-w1,0.3,0.7)*np.array([[1,1,1]]), edgecolor=np.clip(1-10*w1,0,1)*np.array([[1,1,1]]), linewidth=0.1, marker=\"o\")\n", " # axs[0].scatter(*ys[1].T, s=10, c=np.clip(1-w2,0.3,0.7)*np.array([[1,1,1]]), edgecolor=np.clip(1-10*w2,0,1)*np.array([[1,1,1]]), linewidth=0.1, marker=\"o\")\n", " # axs[0].scatter(*ys[0].T, s=30, c=cs[0], vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n", " # axs[0].scatter(*ys[1].T, s=30, c=cs[1], vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n", " axs[0].scatter(*ys_all.T, s=30, c=cs_all, vmin=0, vmax=1, cmap=\"binary\", marker=\"o\")\n", - " axs[0].scatter(*zs[0,all_valid].T, s=s_inter, c=qs[all_valid,0], cmap=\"viridis\", vmin=qs.min(), vmax=qs.max(), marker=\"o\",alpha=1)\n", - " axs[0].scatter(*zs[1,all_valid].T, s=s_inter, c=qs[all_valid,1], cmap=\"viridis\", vmin=qs.min(), vmax=qs.max(), marker=\"o\",alpha=1);\n", - "\n", + " axs[0].scatter(\n", + " *zs[0, all_valid].T,\n", + " s=s_inter,\n", + " c=qs[all_valid, 0],\n", + " cmap=\"viridis\",\n", + " vmin=qs.min(),\n", + " vmax=qs.max(),\n", + " marker=\"o\",\n", + " alpha=1,\n", + " )\n", + " axs[0].scatter(\n", + " *zs[1, all_valid].T,\n", + " s=s_inter,\n", + " c=qs[all_valid, 1],\n", + " cmap=\"viridis\",\n", + " vmin=qs.min(),\n", + " vmax=qs.max(),\n", + " marker=\"o\",\n", + " alpha=1,\n", + " )\n", " # axs[0].scatter(*ds[all_valid].T, s=20, c=qs[valid[1]*valid[0],2], vmin=qs.min(), vmax=qs.max(), marker=\"s\",alpha=1);\n", - " axs[0].scatter(*(zmax * vs)[all_valid].T, s=s_inter, c=qs[valid[1]*valid[0],2], vmin=qs.min(), vmax=qs.max(), marker=\"o\", alpha=1);\n", - "\n", - " axs[0].plot(*ds[all_valid].T, c=\"C0\", linestyle=\"-\");\n", - "\n", - "\n", - "widget = interactive(func,\n", - " w1 = FloatSlider(min=0.0, max=1., step=0.01, value=1.0, continuous_update=False, description='w1'),\n", - " w2 = FloatSlider(min=0.0, max=1., step=0.01, value=1.0, continuous_update=False, description='w2')\n", + " axs[0].scatter(\n", + " *(zmax * vs)[all_valid].T,\n", + " s=s_inter,\n", + " c=qs[valid[1] * valid[0], 2],\n", + " vmin=qs.min(),\n", + " vmax=qs.max(),\n", + " marker=\"o\",\n", + " alpha=1,\n", + " )\n", + " 
axs[0].plot(*ds[all_valid].T, c=\"C0\", linestyle=\"-\")\n", + "\n", + "\n", + "widget = interactive(\n", + " func,\n", + " w1=FloatSlider(\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " value=1.0,\n", + " continuous_update=False,\n", + " description=\"w1\",\n", + " ),\n", + " w2=FloatSlider(\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " value=1.0,\n", + " continuous_update=False,\n", + " description=\"w2\",\n", + " ),\n", ")\n", "widget" ] @@ -664,9 +780,9 @@ "metadata": {}, "outputs": [], "source": [ - "THRESH_99 = gaussian(jnp.array([4,0,0]), jnp.zeros(3), jnp.eye(3))\n", - "THRESH_97 = gaussian(jnp.array([3,0,0]), jnp.zeros(3), jnp.eye(3))\n", - "THRESH_73 = gaussian(jnp.array([2,0,0]), jnp.zeros(3), jnp.eye(3))" + "THRESH_99 = gaussian(jnp.array([4, 0, 0]), jnp.zeros(3), jnp.eye(3))\n", + "THRESH_97 = gaussian(jnp.array([3, 0, 0]), jnp.zeros(3), jnp.eye(3))\n", + "THRESH_73 = gaussian(jnp.array([2, 0, 0]), jnp.zeros(3), jnp.eye(3))" ] }, { @@ -675,35 +791,45 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def _cast_ray(v, mus, precisions, colors, weights, zmax=2.0, bg_color=jnp.array([1.,1.,1.,1.])):\n", + "# |export\n", + "def _cast_ray(\n", + " v,\n", + " mus,\n", + " precisions,\n", + " colors,\n", + " weights,\n", + " zmax=2.0,\n", + " bg_color=jnp.array([1.0, 1.0, 1.0, 1.0]),\n", + "):\n", " # TODO: Deal with negative intersections behind the camera\n", " # TODO: Maybe switch to log probs?\n", "\n", - " # Compute fuzzy intersections `xs` with Gaussians and \n", + " # Compute fuzzy intersections `xs` with Gaussians and\n", " # their function values `sigmas`\n", - " ts, sigmas = vmap(weighted_argmax_intersection, (0,0,0,None,None))(\n", - " mus, precisions, weights, jnp.zeros(3), v)\n", - " order = jnp.argsort(ts)\n", - " ts = ts[order]\n", + " ts, sigmas = vmap(weighted_argmax_intersection, (0, 0, 0, None, None))(\n", + " mus, precisions, weights, jnp.zeros(3), v\n", + " )\n", + " order = jnp.argsort(ts)\n", + " ts = ts[order]\n", " sigmas = sigmas[order]\n", - " xs = ts[:,None]*v[None,:]\n", + " xs = ts[:, None] * v[None, :]\n", "\n", " # TODO: Ensure that alphas are in [0,1]\n", " # TODO: Should we reset the color opacity to `op`?\n", " # Alternatively we can set `alphas = (1 - jnp.exp(-sigmas*1.0))` -- cf. 
Fuzzy Metaballs paper\n", " alphas = sigmas * (ts > 0)\n", " arrival_probs = discrete_arrival_probabilities(alphas)\n", - " op = 1 - arrival_probs[-1] # Opacity\n", - " mean_depth = jnp.sum(arrival_probs[:-1]*xs[:,2]) \\\n", - " + arrival_probs[-1]*zmax\n", - " mean_color = jnp.sum(arrival_probs[:-1,None]*colors[order], axis=0) \\\n", - " + arrival_probs[-1]*bg_color \n", + " op = 1 - arrival_probs[-1] # Opacity\n", + " mean_depth = jnp.sum(arrival_probs[:-1] * xs[:, 2]) + arrival_probs[-1] * zmax\n", + " mean_color = (\n", + " jnp.sum(arrival_probs[:-1, None] * colors[order], axis=0)\n", + " + arrival_probs[-1] * bg_color\n", + " )\n", "\n", " return mean_depth, mean_color, op\n", "\n", "\n", - "cast_rays = jit(vmap(_cast_ray, (0,None,None,None,None,None,None)))" + "cast_rays = jit(vmap(_cast_ray, (0, None, None, None, None, None, None)))" ] }, { @@ -715,12 +841,15 @@ "w = 100\n", "h = 100\n", "f = 300\n", - "intr = Intrinsics(width = w, height = h,fx = f, fy = f,\n", - " cx = w/2 - 0.5, cy = h/2 - 0.5, near = 1e-6, far = 5.0)\n", + "intr = Intrinsics(\n", + " width=w, height=h, fx=f, fy=f, cx=w / 2 - 0.5, cy=h / 2 - 0.5, near=1e-6, far=5.0\n", + ")\n", "\n", - "cam_pose = transform_from_pos_target_up(0.7*jnp.array([1.,0,0.5]), jnp.array([0,0.03,0]), jnp.array([0,1,0]))\n", - "cam_K = K_from_intrinsics(intr)\n", - "rays = camera_rays_from_intrinsics(intr)" + "cam_pose = transform_from_pos_target_up(\n", + " 0.7 * jnp.array([1.0, 0, 0.5]), jnp.array([0, 0.03, 0]), jnp.array([0, 1, 0])\n", + ")\n", + "cam_K = K_from_intrinsics(intr)\n", + "rays = camera_rays_from_intrinsics(intr)" ] }, { @@ -737,8 +866,8 @@ } ], "source": [ - "bit_to_mb = 1.25e-7\n", - "bit_to_gb = 1.25e-10\n", + "bit_to_mb = 1.25e-7\n", + "bit_to_gb = 1.25e-10\n", "n_gaussians = 100\n", "\n", "print(f\"{w*h*n_gaussians * 32 * bit_to_gb} GB\")" @@ -780,37 +909,46 @@ "from sklearn.mixture import GaussianMixture\n", "\n", "\n", - "data = jnp.load('data/gaussians_banana_1550.npz')\n", - "valid = data[\"nums\"] > 0 \n", - "mus = data[\"mus\"][valid]\n", - "covs = data[\"covs\"][valid]\n", + "data = jnp.load(\"data/gaussians_banana_1550.npz\")\n", + "valid = data[\"nums\"] > 0\n", + "mus = data[\"mus\"][valid]\n", + "covs = data[\"covs\"][valid]\n", "colors = data[\"mean_colors\"][valid]\n", "\n", "print(f\"\"\"\n", "{jnp.sum(valid)} of {len(valid)} Gaussians have data associated to them.\n", "\"\"\")\n", "\n", - "mus = (mus - cam_pose[:3,3]) @ cam_pose[:3,:3] # same as mapping `mu -> inv(cam_pose) @ mu``\n", - "covs = vmap(lambda cov: cam_pose[:3,:3] @ cov @ cam_pose[:3,:3].T)(covs) \n", - "precisions = vmap(inv)(2.**2*covs)\n", - "weights = jnp.ones(len(mus))\n", - "zmax = 5.0\n", - "bg_color = jnp.array([1.,1.,1.,1.])\n", + "mus = (mus - cam_pose[:3, 3]) @ cam_pose[\n", + " :3, :3\n", + "] # same as mapping `mu -> inv(cam_pose) @ mu``\n", + "covs = vmap(lambda cov: cam_pose[:3, :3] @ cov @ cam_pose[:3, :3].T)(covs)\n", + "precisions = vmap(inv)(2.0**2 * covs)\n", + "weights = jnp.ones(len(mus))\n", + "zmax = 5.0\n", + "bg_color = jnp.array([1.0, 1.0, 1.0, 1.0])\n", "\n", "\n", - "zs, cs, _ = cast_rays(rays.reshape(-1,3), mus, precisions, colors, weights, zmax, bg_color)\n", + "zs, cs, _ = cast_rays(\n", + " rays.reshape(-1, 3), mus, precisions, colors, weights, zmax, bg_color\n", + ")\n", "zs = zs.reshape(intr.height, intr.width)\n", "cs = cs.reshape(intr.height, intr.width, -1)\n", "\n", "\n", - "\n", - "gm = GaussianMixture(n_components=3).fit(zs.reshape(-1,1))\n", + "gm = GaussianMixture(n_components=3).fit(zs.reshape(-1, 
1))\n", "# ==============================\n", - "fig, axs = plt.subplots(1,2, figsize=(4,4))\n", + "fig, axs = plt.subplots(1, 2, figsize=(4, 4))\n", "axs[0].set_title(\"RGB\")\n", "axs[0].imshow(cs, interpolation=\"nearest\")\n", "axs[1].set_title(\"Z-Depth\")\n", - "axs[1].imshow(zs, cmap=\"viridis_r\", vmin=z_mean-3*np.sqrt(z_var), vmax=z_mean+3*np.sqrt(z_var), interpolation=\"nearest\")\n", + "axs[1].imshow(\n", + " zs,\n", + " cmap=\"viridis_r\",\n", + " vmin=z_mean - 3 * np.sqrt(z_var),\n", + " vmax=z_mean + 3 * np.sqrt(z_var),\n", + " interpolation=\"nearest\",\n", + ")\n", "fig.tight_layout()" ] }, @@ -820,7 +958,7 @@ "metadata": {}, "outputs": [], "source": [ - "cloud = unproject_depth(zs[::-1,::-1], intr).reshape(-1,3)" + "cloud = unproject_depth(zs[::-1, ::-1], intr).reshape(-1, 3)" ] }, { @@ -841,22 +979,24 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"spheres\"})\n", "msg.payload.data.MergeFrom(\n", - " traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(cloud), \n", - " 'colors': np.array(cs[::-1,::-1].reshape(-1,4)), \n", - " \"scales\": 0.002*np.ones(len(cloud))\n", - " })\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(cloud),\n", + " \"colors\": np.array(cs[::-1, ::-1].reshape(-1, 4)),\n", + " \"scales\": 0.002 * np.ones(len(cloud)),\n", + " }\n", + " )\n", ")\n", - " \n", + "\n", "\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { diff --git a/scripts/_mkl/notebooks/07 - Gaussian Sensor Model.ipynb b/scripts/_mkl/notebooks/07 - Gaussian Sensor Model.ipynb index 4db91a9e..32e70b12 100644 --- a/scripts/_mkl/notebooks/07 - Gaussian Sensor Model.ipynb +++ b/scripts/_mkl/notebooks/07 - Gaussian Sensor Model.ipynb @@ -16,7 +16,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp gaussian_sensor_model" + "# |default_exp gaussian_sensor_model" ] }, { @@ -40,7 +40,7 @@ } ], "source": [ - "#|export\n", + "# |export\n", "import bayes3d as b3d\n", "import trimesh\n", "import os\n", @@ -52,12 +52,16 @@ "import jax.numpy as jnp\n", "from functools import partial\n", "from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics\n", - "from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth\n", + "from bayes3d.transforms_3d import (\n", + " transform_from_pos_target_up,\n", + " add_homogenous_ones,\n", + " unproject_depth,\n", + ")\n", "import tensorflow_probability as tfp\n", "from tensorflow_probability.substrates.jax.math import lambertw\n", "\n", - "normal_cdf = jax.scipy.stats.norm.cdf\n", - "normal_pdf = jax.scipy.stats.norm.pdf\n", + "normal_cdf = jax.scipy.stats.norm.cdf\n", + "normal_pdf = jax.scipy.stats.norm.pdf\n", "normal_logpdf = jax.scipy.stats.norm.logpdf\n", "inv = jnp.linalg.inv\n", "\n", @@ -71,7 +75,7 @@ "outputs": [], "source": [ "import traceviz.client\n", - "from traceviz.proto import viz_pb2\n", + "from traceviz.proto import viz_pb2\n", "import json\n", "import inspect\n", "from IPython.display import Markdown as md\n", @@ -84,7 +88,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from bayes3d._mkl.types import *" ] }, @@ 
-94,8 +98,8 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix:\n", + "# |export\n", + "def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix:\n", " \"\"\"Returns A with cov = A@A.T\"\"\"\n", " sigma, U = jnp.linalg.eigh(cov)\n", " D = jnp.diag(jnp.sqrt(sigma))\n", @@ -108,8 +112,8 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def bilinear(x:Array, y:Array, A:Matrix) -> Float:\n", + "# |export\n", + "def bilinear(x: Array, y: Array, A: Matrix) -> Float:\n", " return x.T @ A @ y" ] }, @@ -119,21 +123,21 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def log_gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float:\n", + "# |export\n", + "def log_gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float:\n", " \"\"\"Evaluate an **unnormalized** gaussian at a given point.\"\"\"\n", - " return -0.5 * bilinear(x-mu, x-mu, P)\n", + " return -0.5 * bilinear(x - mu, x - mu, P)\n", "\n", "\n", - "def gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float:\n", + "def gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float:\n", " \"\"\"Evaluate an **unnormalized** gaussian at a given point.\"\"\"\n", - " return jnp.exp(-0.5 * bilinear(x-mu, x-mu, P))\n", + " return jnp.exp(-0.5 * bilinear(x - mu, x - mu, P))\n", "\n", "\n", - "def gaussian_normalizing_constant(P:PrecisionMatrix) -> Float:\n", + "def gaussian_normalizing_constant(P: PrecisionMatrix) -> Float:\n", " \"\"\"Returns the normalizing constant of an unnormalized gaussian.\"\"\"\n", " n = P.shape[0]\n", - " return jnp.sqrt(jnp.linalg.det(P)/(2*jnp.pi)**n)" + " return jnp.sqrt(jnp.linalg.det(P) / (2 * jnp.pi) ** n)" ] }, { @@ -142,17 +146,19 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def gaussian_restriction_to_ray(loc:Vector, P:PrecisionMatrix, A:CholeskyMatrix, x:Vector, v:Direction):\n", + "# |export\n", + "def gaussian_restriction_to_ray(\n", + " loc: Vector, P: PrecisionMatrix, A: CholeskyMatrix, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", - " Restricts a gaussian to a ray and returns \n", - " the mean `mu` and standard deviation `std`, s.t. we have \n", + " Restricts a gaussian to a ray and returns\n", + " the mean `mu` and standard deviation `std`, s.t. 
we have\n", " $$\n", " P(x + t*v | loc, cov) = P(x + mu*v | loc, cov) * N(t | mu, std)\n", " $$\n", " \"\"\"\n", - " mu = bilinear(loc - x, v, P)/bilinear(v, v, P)\n", - " std = 1/jnp.linalg.norm(inv(A)@v)\n", + " mu = bilinear(loc - x, v, P) / bilinear(v, v, P)\n", + " std = 1 / jnp.linalg.norm(inv(A) @ v)\n", " return mu, std" ] }, @@ -169,17 +175,17 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def discrete_arrival_probabilities(occupancy_probs:Vector):\n", + "# |export\n", + "def discrete_arrival_probabilities(occupancy_probs: Vector):\n", " \"\"\"\n", - " Given an vector of `n` occupancy probabilities of neighbouring pixels, \n", - " it returns a vector of length `n+1` containing the probabilities of stopping \n", + " Given an vector of `n` occupancy probabilities of neighbouring pixels,\n", + " it returns a vector of length `n+1` containing the probabilities of stopping\n", " at a each pixel (while traversing them left to right) or not stopping at all.\n", "\n", " The return array is given by:\n", " $$\n", " q_i = p_i \\cdot \\prod_{j=0}^{i-1} (1 - p_j)\n", - " \n", + "\n", " $$\n", " for $i=0,...,n-1$, and\n", " $$\n", @@ -192,7 +198,9 @@ " X(T) = \\sigma(T)*\\exp(- \\int_0^T \\sigma(t) \\ dt).\n", " $$\n", " \"\"\"\n", - " transmittances = jnp.concatenate([jnp.array([1.0]), jnp.cumprod(1-occupancy_probs)])\n", + " transmittances = jnp.concatenate(\n", + " [jnp.array([1.0]), jnp.cumprod(1 - occupancy_probs)]\n", + " )\n", " extended_occupancies = jnp.concatenate([occupancy_probs, jnp.array([1.0])])\n", " return extended_occupancies * transmittances" ] @@ -215,15 +223,15 @@ ], "source": [ "key = keysplit(key)\n", - "occupancy_probs = 0.3*jax.random.uniform(key, (10,))\n", - "arrival_probs = discrete_arrival_probabilities(occupancy_probs)\n", + "occupancy_probs = 0.3 * jax.random.uniform(key, (10,))\n", + "arrival_probs = discrete_arrival_probabilities(occupancy_probs)\n", "\n", "assert jnp.isclose(arrival_probs.sum(), 1.0)\n", "\n", "# =======================\n", - "plt.figure(figsize=(3,1))\n", + "plt.figure(figsize=(3, 1))\n", "plt.plot(jnp.arange(len(occupancy_probs)), occupancy_probs)\n", - "plt.plot(jnp.arange(len(occupancy_probs)+1), arrival_probs);" + "plt.plot(jnp.arange(len(occupancy_probs) + 1), arrival_probs);" ] }, { @@ -232,7 +240,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def gaussian_time_of_arrival(xs, mu, sig, w=1.0):\n", " \"\"\"\n", " Time of first arrival for a **single** weighted 1-dimensional Gaussian, i.e. returns an array of\n", @@ -241,23 +249,28 @@ " Y(T) = w*g(T | \\mu, \\sigma)*\\exp(- \\int_0^T w*g(t | \\mu, \\sigma) \\ dt).\n", " $$\n", " \"\"\"\n", - " ys = w*normal_pdf(xs, loc=mu, scale=sig) * jnp.exp(\n", - " - w*normal_cdf(xs, loc=mu, scale=sig) \n", - " + w*normal_cdf(0.0, loc=mu, scale=sig))\n", - " return ys \n", - "\n", - "\n", - "def gaussian_most_likely_time_of_arrival(mu, sig, w=1.):\n", + " ys = (\n", + " w\n", + " * normal_pdf(xs, loc=mu, scale=sig)\n", + " * jnp.exp(\n", + " -w * normal_cdf(xs, loc=mu, scale=sig)\n", + " + w * normal_cdf(0.0, loc=mu, scale=sig)\n", + " )\n", + " )\n", + " return ys\n", + "\n", + "\n", + "def gaussian_most_likely_time_of_arrival(mu, sig, w=1.0):\n", " \"\"\"\n", " Returns the most likely time of first arrival\n", - " for a single weighted 1-dimensional Gaussian, i.e. the argmax of \n", + " for a single weighted 1-dimensional Gaussian, i.e. 
the argmax of\n", " $$\n", " Y(T) = w*g(T | \\mu, \\sigma)*\\exp(- \\int_0^T w*g(t | \\mu, \\sigma) \\ dt).\n", " $$\n", " \"\"\"\n", " # TODO: Check if this is correct, cf. my notes.\n", - " Z = jnp.sqrt(lambertw(1/(2*jnp.pi) * w**2))\n", - " return mu - Z*sig" + " Z = jnp.sqrt(lambertw(1 / (2 * jnp.pi) * w**2))\n", + " return mu - Z * sig" ] }, { @@ -289,13 +302,13 @@ } ], "source": [ - "key, keys = keysplit(key,1,3)\n", + "key, keys = keysplit(key, 1, 3)\n", "\n", - "xs = jnp.linspace(0,11,20_000)\n", - "mu = jax.random.uniform(keys[0],(), float, 3.0,8.0)\n", - "sig = jax.random.uniform(keys[1],(), float, 0.1,2.0)\n", - "w = jax.random.uniform(keys[2],(), float, 0.1,1.0)\n", - "t = gaussian_most_likely_time_of_arrival(mu, sig, w=w)\n", + "xs = jnp.linspace(0, 11, 20_000)\n", + "mu = jax.random.uniform(keys[0], (), float, 3.0, 8.0)\n", + "sig = jax.random.uniform(keys[1], (), float, 0.1, 2.0)\n", + "w = jax.random.uniform(keys[2], (), float, 0.1, 1.0)\n", + "t = gaussian_most_likely_time_of_arrival(mu, sig, w=w)\n", "\n", "print(f\"\"\"\n", " mu: {mu:0.3f}\n", @@ -305,16 +318,28 @@ "\"\"\")\n", "\n", "# =======================\n", - "plt.figure(figsize=(10,1))\n", - "plt.gca().spines['right'].set_visible(False)\n", - "plt.gca().spines['top'].set_visible(False)\n", - "plt.plot(xs, gaussian_time_of_arrival(xs, mu, sig, w), c=\"C1\", alpha=1., label=\"Time of first Arrival\")\n", + "plt.figure(figsize=(10, 1))\n", + "plt.gca().spines[\"right\"].set_visible(False)\n", + "plt.gca().spines[\"top\"].set_visible(False)\n", + "plt.plot(\n", + " xs,\n", + " gaussian_time_of_arrival(xs, mu, sig, w),\n", + " c=\"C1\",\n", + " alpha=1.0,\n", + " label=\"Time of first Arrival\",\n", + ")\n", "plt.fill_between(xs, gaussian_time_of_arrival(xs, mu, sig, w), color=\"C1\", alpha=0.1)\n", - "plt.vlines(t, 0, gaussian_time_of_arrival(t, mu, sig, w), color=\"C1\", alpha=.5)\n", + "plt.vlines(t, 0, gaussian_time_of_arrival(t, mu, sig, w), color=\"C1\", alpha=0.5)\n", "plt.scatter(t, gaussian_time_of_arrival(t, mu, sig, w), c=\"C1\")\n", - "plt.plot(xs, w*normal_pdf(xs, loc=mu, scale=sig), c=\"C0\", alpha=1., label=\"Weighted Gaussian\")\n", - "plt.vlines(mu, 0, w*normal_pdf(mu, loc=mu, scale=sig), color=\"C0\", alpha=.5)\n", - "plt.scatter(mu, w*normal_pdf(mu, loc=mu, scale=sig), c=\"C0\")\n", + "plt.plot(\n", + " xs,\n", + " w * normal_pdf(xs, loc=mu, scale=sig),\n", + " c=\"C0\",\n", + " alpha=1.0,\n", + " label=\"Weighted Gaussian\",\n", + ")\n", + "plt.vlines(mu, 0, w * normal_pdf(mu, loc=mu, scale=sig), color=\"C0\", alpha=0.5)\n", + "plt.scatter(mu, w * normal_pdf(mu, loc=mu, scale=sig), c=\"C0\")\n", "plt.legend();" ] }, @@ -331,15 +356,17 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "def weighted_arrival_intersection(mu:Vector, P:PrecisionMatrix, A:CholeskyMatrix, w:Float, x:Vector, v:Direction):\n", + "# |export\n", + "def weighted_arrival_intersection(\n", + " mu: Vector, P: PrecisionMatrix, A: CholeskyMatrix, w: Float, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", " Returns the \"intersection\" of a ray with a gaussian which we define as\n", " the mode of the gaussian restricted to the ray.\n", " \"\"\"\n", " t0, sig0 = gaussian_restriction_to_ray(mu, P, A, x, v)\n", - " w0 = w*gaussian(t0*v, mu, P)\n", - " Z = w0/gaussian_normalizing_constant(P)\n", + " w0 = w * gaussian(t0 * v, mu, P)\n", + " Z = w0 / gaussian_normalizing_constant(P)\n", " t = gaussian_most_likely_time_of_arrival(t0, sig0, Z)\n", " return t, w0" ] @@ -350,14 +377,16 @@ "metadata": {}, "outputs": [], 
"source": [ - "#|export\n", - "def weighted_argmax_intersection(mu:Vector, P:PrecisionMatrix, w:Float, x:Vector, v:Direction):\n", + "# |export\n", + "def weighted_argmax_intersection(\n", + " mu: Vector, P: PrecisionMatrix, w: Float, x: Vector, v: Direction\n", + "):\n", " \"\"\"\n", " Returns the \"intersection\" of a ray with a gaussian which we define as\n", " the mode of the gaussian restricted to the ray.\n", " \"\"\"\n", - " t = bilinear(mu - x, v, P)/bilinear(v, v, P)\n", - " return t, w*gaussian(x + t*v, mu, P)" + " t = bilinear(mu - x, v, P) / bilinear(v, v, P)\n", + " return t, w * gaussian(x + t * v, mu, P)" ] }, { @@ -373,9 +402,9 @@ "metadata": {}, "outputs": [], "source": [ - "THRESH_99 = gaussian(jnp.array([4,0,0]), jnp.zeros(3), jnp.eye(3))\n", - "THRESH_97 = gaussian(jnp.array([3,0,0]), jnp.zeros(3), jnp.eye(3))\n", - "THRESH_73 = gaussian(jnp.array([2,0,0]), jnp.zeros(3), jnp.eye(3))" + "THRESH_99 = gaussian(jnp.array([4, 0, 0]), jnp.zeros(3), jnp.eye(3))\n", + "THRESH_97 = gaussian(jnp.array([3, 0, 0]), jnp.zeros(3), jnp.eye(3))\n", + "THRESH_73 = gaussian(jnp.array([2, 0, 0]), jnp.zeros(3), jnp.eye(3))" ] }, { @@ -395,7 +424,7 @@ } ], "source": [ - "order = jnp.array([2,0,3,1])\n", + "order = jnp.array([2, 0, 3, 1])\n", "jnp.arange(4)[order]\n", "jnp.argsort(order)" ] @@ -406,44 +435,53 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def _particle_arrival_probs(\n", - " v: \"Ray direction\", \n", - " particle_positions, \n", - " precisions, \n", - " embeddings, \n", - " weights): \n", + " v: \"Ray direction\", particle_positions, precisions, embeddings, weights\n", + "):\n", " \"\"\"\n", - " Returns the arrival particles for each particles along a ray \n", + " Returns the arrival particles for each particles along a ray\n", " (in the order of the particles).\n", " \"\"\"\n", " # TODO: Deal with negative intersections behind the camera\n", " # TODO: Maybe switch to log probs?\n", "\n", - " # Compute fuzzy intersections `xs` with Gaussians and \n", + " # Compute fuzzy intersections `xs` with Gaussians and\n", " # their function values `sigmas`\n", " origin = jnp.zeros(3)\n", - " ts, sigmas = vmap(weighted_arrival_intersection, (0,0,0,0,None,None))(\n", - " particle_positions, precisions, embeddings, weights, origin, v)\n", + " ts, sigmas = vmap(weighted_arrival_intersection, (0, 0, 0, 0, None, None))(\n", + " particle_positions, precisions, embeddings, weights, origin, v\n", + " )\n", + "\n", + " order = jnp.argsort(ts)\n", "\n", - " order = jnp.argsort(ts)\n", - " \n", - " ts = ts[order]\n", + " ts = ts[order]\n", " sigmas = sigmas[order]\n", - " xs = ts[:,None]*v[None,:]\n", + " xs = ts[:, None] * v[None, :]\n", "\n", " # TODO: Ensure that alphas are in [0,1]\n", " # TODO: Should we reset the color opacity to `op`?\n", " # Alternatively we can set `alphas = (1 - jnp.exp(-sigmas*1.0))` -- cf. 
Fuzzy Metaballs paper\n", " alphas = sigmas * (ts > 0)\n", " arrival_probs = discrete_arrival_probabilities(alphas)\n", - " op = 1 - arrival_probs[-1] # Opacity\n", + " op = 1 - arrival_probs[-1] # Opacity\n", " inverse_order = jnp.concatenate([jnp.argsort(order), jnp.array([len(order)])])\n", "\n", " return arrival_probs[inverse_order]\n", "\n", "\n", - "particle_arrival_probs = jit(vmap(_particle_arrival_probs, (0,None,None,None,None,)))" + "particle_arrival_probs = jit(\n", + " vmap(\n", + " _particle_arrival_probs,\n", + " (\n", + " 0,\n", + " None,\n", + " None,\n", + " None,\n", + " None,\n", + " ),\n", + " )\n", + ")" ] }, { @@ -459,36 +497,30 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def make_sensor_model(rays):\n", - "\n", " @genjax.Static\n", " def sensor_model(\n", - " cam_K, \n", - " particle_positions, \n", - " precisions, \n", - " embeddings, \n", - " weights, \n", - " colors, \n", - " mask, \n", - " bg_color=jnp.array([1.,1.,1.,1.])):\n", - "\n", - " R = cam_K[:3,:3]\n", - " x = cam_K[:3,3]\n", - " positions = (particle_positions - x) @ R # same as mapping `mu -> inv(cam_pose) @ mu`\n", + " cam_K,\n", + " particle_positions,\n", + " precisions,\n", + " embeddings,\n", + " weights,\n", + " colors,\n", + " mask,\n", + " bg_color=jnp.array([1.0, 1.0, 1.0, 1.0]),\n", + " ):\n", + " R = cam_K[:3, :3]\n", + " x = cam_K[:3, 3]\n", + " positions = (\n", + " particle_positions - x\n", + " ) @ R # same as mapping `mu -> inv(cam_pose) @ mu`\n", " precisions = vmap(lambda P: R.T @ P @ R)(precisions)\n", " embeddings = vmap(lambda E: R.T @ E @ R)(embeddings)\n", - " \n", - " probs = particle_arrival_probs(\n", - " rays.reshape(-1,3),\n", - " positions, \n", - " precisions, \n", - " embeddings, \n", - " weights)\n", - "\n", - " \n", - "\n", "\n", + " probs = particle_arrival_probs(\n", + " rays.reshape(-1, 3), positions, precisions, embeddings, weights\n", + " )\n", "\n", " return sensor_model" ] @@ -502,12 +534,15 @@ "w = 100\n", "h = 100\n", "f = 300\n", - "intr = Intrinsics(width = w, height = h,fx = f, fy = f,\n", - " cx = w/2 - 0.5, cy = h/2 - 0.5, near = 1e-6, far = 5.0)\n", - "\n", - "cam_pose = transform_from_pos_target_up(.5*jnp.array([.1, 0., 1.5]), jnp.array([0,0.0,0]), jnp.array([0,1,0]))\n", - "cam_K = K_from_intrinsics(intr)\n", - "rays = camera_rays_from_intrinsics(intr)" + "intr = Intrinsics(\n", + " width=w, height=h, fx=f, fy=f, cx=w / 2 - 0.5, cy=h / 2 - 0.5, near=1e-6, far=5.0\n", + ")\n", + "\n", + "cam_pose = transform_from_pos_target_up(\n", + " 0.5 * jnp.array([0.1, 0.0, 1.5]), jnp.array([0, 0.0, 0]), jnp.array([0, 1, 0])\n", + ")\n", + "cam_K = K_from_intrinsics(intr)\n", + "rays = camera_rays_from_intrinsics(intr)" ] }, { @@ -523,9 +558,9 @@ "data = jnp.load(f\"data/gaussian_examples/gaussians_{name}_{num_components}.npz\")\n", "\n", "\n", - "mus = data[\"mus\"]\n", - "covs = data[\"covs\"]\n", - "cols = data[\"colors\"]\n" + "mus = data[\"mus\"]\n", + "covs = data[\"covs\"]\n", + "cols = data[\"colors\"]" ] }, { @@ -562,31 +597,32 @@ } ], "source": [ - "positions = (mus - cam_pose[:3,3]) @ cam_pose[:3,:3] # same as mapping `mu -> inv(cam_pose) @ mu`\n", - "covariances = 2.**2*vmap(lambda cov: cam_pose[:3,:3].T@cov@cam_pose[:3,:3])(covs)\n", - "embeddings = vmap(ellipsoid_embedding)(covariances)\n", - "precisions = vmap(inv)(covariances)\n", - "weights = .25*jnp.ones(len(positions))\n", - "bg_color = jnp.array([1.,1.,1.,1.])\n", + "positions = (mus - cam_pose[:3, 3]) @ cam_pose[\n", + " :3, :3\n", + "] # same as mapping `mu 
-> inv(cam_pose) @ mu`\n", + "covariances = 2.0**2 * vmap(lambda cov: cam_pose[:3, :3].T @ cov @ cam_pose[:3, :3])(\n", + " covs\n", + ")\n", + "embeddings = vmap(ellipsoid_embedding)(covariances)\n", + "precisions = vmap(inv)(covariances)\n", + "weights = 0.25 * jnp.ones(len(positions))\n", + "bg_color = jnp.array([1.0, 1.0, 1.0, 1.0])\n", "\n", "\n", "colors = jnp.concatenate([cols, bg_color[None]], axis=0)\n", "\n", "probs = particle_arrival_probs(\n", - " rays.reshape(-1,3),\n", - " positions, \n", - " precisions, \n", - " embeddings, \n", - " weights)\n", + " rays.reshape(-1, 3), positions, precisions, embeddings, weights\n", + ")\n", "\n", "\n", "I = jnp.argmax(probs, axis=-1)\n", "\n", - "avg = (probs[:,:,None] * colors[None]).sum(-2)\n", + "avg = (probs[:, :, None] * colors[None]).sum(-2)\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(10,5))\n", - "axs[0].imshow(colors[I].reshape(h,w,-1))\n", - "axs[1].imshow(avg.reshape(h,w,-1))\n", + "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", + "axs[0].imshow(colors[I].reshape(h, w, -1))\n", + "axs[1].imshow(avg.reshape(h, w, -1))\n", "probs.shape" ] }, @@ -610,31 +646,40 @@ "from bayes3d._mkl.trimesh_to_gaussians import pack_transform, ellipsoid_embedding\n", "import traceviz.client\n", "import numpy as np\n", - "from traceviz.proto import viz_pb2\n", + "from traceviz.proto import viz_pb2\n", "import json\n", "\n", "\n", - "transforms = vmap(pack_transform, (0,0,None))(\n", - " positions - positions.mean(axis=0,keepdims=True), \n", - " embeddings, \n", - " 1.0)\n", + "transforms = vmap(pack_transform, (0, 0, None))(\n", + " positions - positions.mean(axis=0, keepdims=True), embeddings, 1.0\n", + ")\n", "\n", "\n", "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\"type\": \"setup\",}))\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"setup\",\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\n", - " \"type\": \"gaussians\",\n", - " \"data\": {\n", - " \"transforms\": np.array(transforms[None]),\n", - " 'colors': np.array(cols[None])\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"gaussians\",\n", + " \"data\": {\n", + " \"transforms\": np.array(transforms[None]),\n", + " \"colors\": np.array(cols[None]),\n", + " },\n", " }\n", - " }))\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -644,7 +689,7 @@ "outputs": [], "source": [ "def logit(x):\n", - " return jnp.log(x/(1-x))" + " return jnp.log(x / (1 - x))" ] }, { @@ -665,40 +710,39 @@ "from PIL import Image\n", "import io\n", "\n", + "\n", "def fig_to_image(fig):\n", " \"\"\"Convert a Matplotlib figure to a PIL Image and return it\"\"\"\n", " buf = io.BytesIO()\n", - " fig.savefig(buf, format='png')\n", + " fig.savefig(buf, format=\"png\")\n", " buf.seek(0)\n", " img = Image.open(buf)\n", " return img\n", "\n", - " \n", + "\n", "images = []\n", "\n", - "avg = jnp.zeros((h,w,4))\n", + "avg = jnp.zeros((h, w, 4))\n", "\n", "num_samples = 40\n", "for t in range(num_samples):\n", - "\n", " key = keysplit(key)\n", " I = jax.random.categorical(key, logit(probs))\n", - " im = colors[I].reshape(h,w,-1)\n", + " im = colors[I].reshape(h, w, 
-1)\n", " avg += im\n", "\n", - "\n", - " fig, axs = plt.subplots(1,2)\n", + " fig, axs = plt.subplots(1, 2)\n", " axs[0].set_title(f\"Sample ({t+1}/{num_samples})\")\n", " axs[0].imshow(im)\n", "\n", " axs[1].set_title(\"Sample mean\")\n", - " axs[1].imshow(avg/(t+1))\n", + " axs[1].imshow(avg / (t + 1))\n", "\n", " axs[0].axis(\"off\")\n", " axs[1].axis(\"off\")\n", " # Convert figure to image and store\n", " images.append(fig_to_image(fig))\n", - " plt.close(fig)\n" + " plt.close(fig)" ] }, { @@ -728,7 +772,7 @@ "\n", "# Save as GIF in memory\n", "gif_io = io.BytesIO()\n", - "imageio.mimsave(gif_io, np_images, format='gif', fps=10, loop=1)\n", + "imageio.mimsave(gif_io, np_images, format=\"gif\", fps=10, loop=1)\n", "gif_io.seek(0)\n", "\n", "# Display the GIF in the notebook\n", @@ -736,7 +780,7 @@ "\n", "# Save the GIF to disk\n", "gif_filename = f\"_outputs/ani_{name}_2.gif\"\n", - "with open(gif_filename, 'wb') as f:\n", + "with open(gif_filename, \"wb\") as f:\n", " f.write(gif_io.getbuffer())\n", "\n", "display(Image(url=gif_filename))" diff --git a/scripts/_mkl/notebooks/08c - Gaussian particle system - Genjax Minimal.ipynb b/scripts/_mkl/notebooks/08c - Gaussian particle system - Genjax Minimal.ipynb index c124029a..4222c04a 100644 --- a/scripts/_mkl/notebooks/08c - Gaussian particle system - Genjax Minimal.ipynb +++ b/scripts/_mkl/notebooks/08c - Gaussian particle system - Genjax Minimal.ipynb @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp gaussian_particle_system_genjax" + "# |default_exp gaussian_particle_system_genjax" ] }, { @@ -34,7 +34,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import bayes3d as b3d\n", "import trimesh\n", "import os\n", @@ -49,13 +49,17 @@ "from functools import partial\n", "import genjax\n", "from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics\n", - "from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth\n", + "from bayes3d.transforms_3d import (\n", + " transform_from_pos_target_up,\n", + " add_homogenous_ones,\n", + " unproject_depth,\n", + ")\n", "import tensorflow_probability as tfp\n", "from tensorflow_probability.substrates.jax.math import lambertw\n", "from typing import Any, NamedTuple\n", "\n", - "normal_cdf = jax.scipy.stats.norm.cdf\n", - "normal_pdf = jax.scipy.stats.norm.pdf\n", + "normal_cdf = jax.scipy.stats.norm.cdf\n", + "normal_pdf = jax.scipy.stats.norm.pdf\n", "normal_logpdf = jax.scipy.stats.norm.logpdf\n", "inv = jnp.linalg.inv\n", "concat = jnp.concatenate\n", @@ -69,7 +73,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "Array = np.ndarray | jax.Array\n", "Shape = int | tuple[int, ...]\n", "Bool = Array\n", @@ -84,12 +88,14 @@ "outputs": [], "source": [ "def logit(x):\n", - " return jnp.log(x/(1-x))\n", + " return jnp.log(x / (1 - x))\n", "\n", - " \n", - "def pack_homogenous_matrix(x: \"Position\",q: \"Quaternion\") -> \"HomogenousMatrix\":\n", + "\n", + "def pack_homogenous_matrix(x: \"Position\", q: \"Quaternion\") -> \"HomogenousMatrix\":\n", " r = Rot.from_quat(q).as_matrix()\n", - " return concat([concat([r, x.reshape(-1,1)], axis=-1), jnp.array([[0.,0.,0.,1.]])])" + " return concat(\n", + " [concat([r, x.reshape(-1, 1)], axis=-1), jnp.array([[0.0, 0.0, 0.0, 1.0]])]\n", + " )" ] }, { @@ -99,7 +105,6 @@ "outputs": [], "source": [ "def multiply_quaternions(q1, q2):\n", - "\n", " w1, x1, y1, z1 = q1[3], q1[0], q1[1], q1[2]\n", " w2, x2, y2, z2 = 
q2[3], q2[0], q2[1], q2[2]\n", "\n", @@ -118,13 +123,12 @@ "outputs": [], "source": [ "def make_constant_model(x, address=\"_ignored\"):\n", - "\n", " @genjax.Static\n", " def constant_model(*args):\n", " # Note that genjax's bernoulli takes logits!\n", " _ = genjax.bernoulli(jnp.inf) @ address\n", " return x\n", - " \n", + "\n", " return constant_model" ] }, @@ -142,21 +146,21 @@ "outputs": [], "source": [ "class Pose(NamedTuple):\n", - " position: Array\n", + " position: Array\n", " quaternion: Array\n", "\n", "\n", "class Cam(NamedTuple):\n", - " pose: Pose\n", + " pose: Pose\n", " intrinsics: Array\n", - " \n", + "\n", "\n", "class ParticleSystem(NamedTuple):\n", - " poses: tuple[Array, Array]\n", + " poses: tuple[Array, Array]\n", " covariances: Array\n", - " colors: Array\n", - " weights: Array\n", - " mask: Array\n", + " colors: Array\n", + " weights: Array\n", + " mask: Array\n", "\n", "\n", "@genjax.Static\n", @@ -171,42 +175,62 @@ " \"\"\"Samples pose with a position within given bounds.\"\"\"\n", " x = genjax.uniform(position_bounds[0], position_bounds[1]) @ \"x\"\n", " q = genjax.normal(jnp.zeros(4), jnp.ones(4)) @ \"q\"\n", - " q = q/jnp.linalg.norm(q)\n", + " q = q / jnp.linalg.norm(q)\n", " return Pose(x, q)\n", "\n", "\n", "def make_gps_prior(max_particles: int):\n", - "\n", " @genjax.Static\n", " def gps_prior(particle_bounds, embedding_bounds):\n", " \"\"\"Naive prior over Gaussian particle systems.\"\"\"\n", "\n", - " particle_mask = genjax.Map(genjax.bernoulli, in_axes=(0,))(\n", - " 0.5*jnp.ones(max_particles))@ \"mask\"\n", - " \n", - " poses = genjax.Map(genjax.masking_combinator(pose_prior), in_axes=(0,(0,)))(\n", - " particle_mask,\n", - " (jnp.tile(particle_bounds, (max_particles,1,1)),)) @ \"poses\"\n", - " \n", - " embs = genjax.Map(genjax.masking_combinator(gaussian_embedding_prior), in_axes=(0,(0,)))(\n", - " particle_mask,\n", - " (jnp.tile(embedding_bounds, (max_particles,1)),)) @ \"embedding_matrices\"\n", - " covs = vmap(lambda emb: emb@emb.T)(embs.value)\n", - "\n", - " cols = genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0,(0,0)))(\n", - " particle_mask,\n", - " (jnp.zeros((max_particles, 3)), jnp.ones((max_particles, 3)),)) @ \"colors\"\n", - "\n", - "\n", - " alphas = genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0,(0,0)))(\n", - " particle_mask,\n", - " (jnp.zeros(max_particles), jnp.ones(max_particles),)) @ \"transparencies\"\n", - "\n", - "\n", - " return ParticleSystem(poses.value, covs, cols.value, alphas.value, particle_mask)\n", - "\n", - "\n", - " return gps_prior\n" + " particle_mask = (\n", + " genjax.Map(genjax.bernoulli, in_axes=(0,))(0.5 * jnp.ones(max_particles))\n", + " @ \"mask\"\n", + " )\n", + "\n", + " poses = (\n", + " genjax.Map(genjax.masking_combinator(pose_prior), in_axes=(0, (0,)))(\n", + " particle_mask, (jnp.tile(particle_bounds, (max_particles, 1, 1)),)\n", + " )\n", + " @ \"poses\"\n", + " )\n", + "\n", + " embs = (\n", + " genjax.Map(\n", + " genjax.masking_combinator(gaussian_embedding_prior), in_axes=(0, (0,))\n", + " )(particle_mask, (jnp.tile(embedding_bounds, (max_particles, 1)),))\n", + " @ \"embedding_matrices\"\n", + " )\n", + " covs = vmap(lambda emb: emb @ emb.T)(embs.value)\n", + "\n", + " cols = (\n", + " genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0, (0, 0)))(\n", + " particle_mask,\n", + " (\n", + " jnp.zeros((max_particles, 3)),\n", + " jnp.ones((max_particles, 3)),\n", + " ),\n", + " )\n", + " @ \"colors\"\n", + " )\n", + "\n", + " alphas = (\n", + " 
genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0, (0, 0)))(\n", + " particle_mask,\n", + " (\n", + " jnp.zeros(max_particles),\n", + " jnp.ones(max_particles),\n", + " ),\n", + " )\n", + " @ \"transparencies\"\n", + " )\n", + "\n", + " return ParticleSystem(\n", + " poses.value, covs, cols.value, alphas.value, particle_mask\n", + " )\n", + "\n", + " return gps_prior" ] }, { @@ -216,11 +240,14 @@ "outputs": [], "source": [ "gps_prior = make_gps_prior(10)\n", - "tr = gps_prior.simulate(key, (\n", - " jnp.array([jnp.zeros(3),jnp.ones(3)]), \n", - " jnp.array([0.,1.]),\n", - "))\n", - "gps = tr.get_retval() " + "tr = gps_prior.simulate(\n", + " key,\n", + " (\n", + " jnp.array([jnp.zeros(3), jnp.ones(3)]),\n", + " jnp.array([0.0, 1.0]),\n", + " ),\n", + ")\n", + "gps = tr.get_retval()" ] }, { @@ -237,7 +264,7 @@ "outputs": [], "source": [ "def make_observation_model():\n", - " return make_constant_model(jnp.array([0.,0.,0.,1.]), address=\"placeholder\") " + " return make_constant_model(jnp.array([0.0, 0.0, 0.0, 1.0]), address=\"placeholder\")" ] }, { @@ -247,30 +274,25 @@ "outputs": [], "source": [ "class Clustering(NamedTuple):\n", - " poses: tuple[Array, Array]\n", + " poses: tuple[Array, Array]\n", " assignments: Array\n", "\n", "\n", "@genjax.Static\n", "def motion_model(p: Pose, std_position, std_quaternion):\n", " \"\"\"Hacked motion model for elements in SE(3).\"\"\"\n", - " x = genjax.normal(p.position, std_position ) @ \"x\"\n", + " x = genjax.normal(p.position, std_position) @ \"x\"\n", " q = genjax.normal(p.quaternion, std_quaternion) @ \"q\"\n", - " q = q/jnp.linalg.norm(q)\n", - " return Pose(x,q)\n", + " q = q / jnp.linalg.norm(q)\n", + " return Pose(x, q)\n", "\n", "\n", "def make_hgps_model(\n", - " max_clusters, \n", - " max_particles, \n", - " max_time_steps=10, \n", - " camera_intrinsics=jnp.array([0])):\n", - " \n", - " \n", - " gps_prior = make_gps_prior(max_particles)\n", + " max_clusters, max_particles, max_time_steps=10, camera_intrinsics=jnp.array([0])\n", + "):\n", + " gps_prior = make_gps_prior(max_particles)\n", " observation_model = make_observation_model()\n", "\n", - "\n", " @genjax.Static\n", " def kernel(state, camera_poses):\n", " t, gps, clustering = state\n", @@ -278,62 +300,70 @@ " # Cluster motion\n", " # ---------------\n", " # TODO: should empty clusters be masked out?\n", - " new_cluster_poses = genjax.Map(motion_model, in_axes=(0,None,None))(\n", - " clustering.poses, jnp.ones(3), jnp.ones(4)) @ \"cluster_poses\"\n", + " new_cluster_poses = (\n", + " genjax.Map(motion_model, in_axes=(0, None, None))(\n", + " clustering.poses, jnp.ones(3), jnp.ones(4)\n", + " )\n", + " @ \"cluster_poses\"\n", + " )\n", "\n", " # Intrinsic motion of particles within their clusters\n", " # ----------------------------------------------------\n", - " new_particle_poses = genjax.Map(genjax.masking_combinator(motion_model), in_axes=(0,(0,None,None)))(\n", - " gps.mask,\n", - " (gps.poses, jnp.ones(3), jnp.ones(4))) @ \"relative_particle_poses\"\n", + " new_particle_poses = (\n", + " genjax.Map(\n", + " genjax.masking_combinator(motion_model), in_axes=(0, (0, None, None))\n", + " )(gps.mask, (gps.poses, jnp.ones(3), jnp.ones(4)))\n", + " @ \"relative_particle_poses\"\n", + " )\n", " new_particle_poses = new_particle_poses.value\n", " # NOTE: Put \"energy\" constraint on particle system, eg., it should be more expensive\n", " # to move the relative pose of particles than moving the cluster. 
This could be done by\n", " # computing the relative pose updates and sampling a value from say a normal distribution.\n", - " # Constraining this value to be zero would put more weight on the updates to be zero. Should \n", - " # we put a prior over the contribution of that? \n", - " # NOTE: Should the particle motion model be something like N(0,I)*p(x_t|x_{t-1})? This would \n", + " # Constraining this value to be zero would put more weight on the updates to be zero. Should\n", + " # we put a prior over the contribution of that?\n", + " # NOTE: Should the particle motion model be something like N(0,I)*p(x_t|x_{t-1})? This would\n", " # put a constraint on the relative motion of particles.\n", "\n", " # Observation model\n", " # ------------------\n", - " # TODO/NOTE: There are 3 versions for an observation model: \n", + " # TODO/NOTE: There are 3 versions for an observation model:\n", " # 1. Use the weighted particles as they are to generate 3d point samples *without* going through a renderer.\n", " # 2. Uses the camera to re-reweight the particles and then uses the re-weighted particles to generate 3d point samples.\n", " # 3. Uses a renderer to generate a 2d-image observation.\n", " obs = observation_model(camera_poses[t], gps) @ \"observation\"\n", "\n", - "\n", - " gps = gps._replace(poses=new_particle_poses)\n", + " gps = gps._replace(poses=new_particle_poses)\n", " clustering = clustering._replace(poses=new_cluster_poses)\n", "\n", - " return (t+1, gps, clustering)\n", + " return (t + 1, gps, clustering)\n", "\n", - " \n", " unfolded_kernel = genjax.Unfold(kernel, max_time_steps)\n", "\n", - "\n", " @genjax.Static\n", " def hgps_model(T, particle_bounds, embedding_bounds, camera_poses):\n", - "\n", " gps = gps_prior(particle_bounds, embedding_bounds) @ \"initial_particle_system\"\n", "\n", - "\n", - " zs = genjax.Map(genjax.masking_combinator(genjax.categorical), in_axes=(0,(0,)))(\n", - " gps.mask,\n", - " (jnp.tile(jnp.ones(max_clusters), (max_particles, 1)),)) @ \"initial_assignments\"\n", + " zs = (\n", + " genjax.Map(\n", + " genjax.masking_combinator(genjax.categorical), in_axes=(0, (0,))\n", + " )(gps.mask, (jnp.tile(jnp.ones(max_clusters), (max_particles, 1)),))\n", + " @ \"initial_assignments\"\n", + " )\n", " zs = zs.value\n", "\n", " # TODO: should empty clusters be masked out?\n", - " qs = genjax.Map(pose_prior, in_axes=(0,))(\n", - " jnp.tile(particle_bounds, (max_clusters,1,1))) @ \"initial_coordinate_frames\"\n", + " qs = (\n", + " genjax.Map(pose_prior, in_axes=(0,))(\n", + " jnp.tile(particle_bounds, (max_clusters, 1, 1))\n", + " )\n", + " @ \"initial_coordinate_frames\"\n", + " )\n", " clustering = Clustering(qs, zs)\n", - " \n", + "\n", " state0 = (0, gps, clustering)\n", " states = unfolded_kernel(T, state0, camera_poses) @ \"chain\"\n", "\n", " return states\n", - " \n", "\n", " return hgps_model" ] @@ -344,15 +374,15 @@ "metadata": {}, "outputs": [], "source": [ - "max_clusters = 5\n", + "max_clusters = 5\n", "max_particles = 7\n", - "max_T = 10\n", + "max_T = 10\n", "hgps_model = make_hgps_model(max_clusters, max_particles, max_T)\n", "\n", "T = 9\n", - "camera_poses = jnp.tile(jnp.eye(4), (max_T,1,1))\n", - "particle_bounds = jnp.array([-jnp.ones(3), jnp.ones(3)])\n", - "embedding_bounds = jnp.array([0.,1.])\n", + "camera_poses = jnp.tile(jnp.eye(4), (max_T, 1, 1))\n", + "particle_bounds = jnp.array([-jnp.ones(3), jnp.ones(3)])\n", + "embedding_bounds = jnp.array([0.0, 1.0])\n", "\n", "tr = hgps_model.simulate(key, (T, particle_bounds, embedding_bounds, 
camera_poses))" ] diff --git a/scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb b/scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb index b2dc27e3..1c577a09 100644 --- a/scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb +++ b/scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp simple_likelihood" + "# |default_exp simple_likelihood" ] }, { @@ -32,12 +32,12 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import jax\n", "import jax.numpy as jnp\n", - "from jax import jit, vmap\n", + "from jax import jit, vmap\n", "import genjax\n", - "from genjax import gen, choice_map, vector_choice_map\n", + "from genjax import gen, choice_map, vector_choice_map\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import bayes3d\n", @@ -52,38 +52,39 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp\n", "import tensorflow_probability.substrates.jax as tfp\n", + "\n", "tfd = tfp.distributions\n", "\n", "uniform = genjax.tfp_uniform\n", "\n", "truncnormal = gentfp.TFPDistribution(\n", - " lambda mu, sig, low, high: tfd.TruncatedNormal(mu, sig, low, high));\n", - "\n", - "normal = gentfp.TFPDistribution(\n", - " lambda mu, sig: tfd.Normal(mu, sig));\n", - "\n", + " lambda mu, sig, low, high: tfd.TruncatedNormal(mu, sig, low, high)\n", + ")\n", + "normal = gentfp.TFPDistribution(lambda mu, sig: tfd.Normal(mu, sig))\n", "diagnormal = gentfp.TFPDistribution(\n", - " lambda mus, sigs: tfd.MultivariateNormalDiag(mus, sigs));\n", - "\n", - "\n", + " lambda mus, sigs: tfd.MultivariateNormalDiag(mus, sigs)\n", + ")\n", "mixture_of_diagnormals = gentfp.TFPDistribution(\n", " lambda ws, mus, sig: tfd.MixtureSameFamily(\n", - " tfd.Categorical(ws),\n", - " tfd.MultivariateNormalDiag(mus, sig * jnp.ones_like(mus))))\n", + " tfd.Categorical(ws), tfd.MultivariateNormalDiag(mus, sig * jnp.ones_like(mus))\n", + " )\n", + ")\n", "\n", "mixture_of_normals = gentfp.TFPDistribution(\n", " lambda ws, mus, sig: tfd.MixtureSameFamily(\n", - " tfd.Categorical(ws),\n", - " tfd.Normal(mus, sig * jnp.ones_like(mus))))\n", + " tfd.Categorical(ws), tfd.Normal(mus, sig * jnp.ones_like(mus))\n", + " )\n", + ")\n", "\n", "\n", "mixture_of_truncnormals = gentfp.TFPDistribution(\n", " lambda ws, mus, sigs, lows, highs: tfd.MixtureSameFamily(\n", - " tfd.Categorical(ws),\n", - " tfd.TruncatedNormal(mus, sigs, lows, highs)))" + " tfd.Categorical(ws), tfd.TruncatedNormal(mus, sigs, lows, highs)\n", + " )\n", + ")" ] }, { @@ -92,12 +93,12 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from scipy.stats import truncnorm as scipy_truncnormal\n", "\n", - "normal_logpdf = jax.scipy.stats.norm.logpdf\n", + "normal_logpdf = jax.scipy.stats.norm.logpdf\n", "truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf\n", - "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf\n" + "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf" ] }, { @@ -123,41 +124,50 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", - "# TODO: The input Y should be an array only containing range measruements as well. 
\n", + "# |export\n", + "# TODO: The input Y should be an array only containing range measruements as well.\n", "# For this to work we need to have the pixel vectors (the rays through each pixel)\n", "\n", + "\n", "def make_simple_sensor_model(zmax):\n", - " \"\"\"Returns an simple sensor model marginalized over outliers.\"\"\" \n", + " \"\"\"Returns an simple sensor model marginalized over outliers.\"\"\"\n", "\n", " @genjax.drop_arguments\n", " @genjax.gen\n", " def _sensor_model(y, sig, outlier):\n", - " \n", - " \n", " # Compute max range along ray ending at far plane\n", " # and adding some wiggle room\n", " z_ = jnp.linalg.norm(y)\n", - " zmax_ = z_/y[2]*zmax\n", + " zmax_ = z_ / y[2] * zmax\n", "\n", - " inlier_outlier_mix = genjax.tfp_mixture(genjax.tfp_categorical, [truncnormal, genjax.tfp_uniform])\n", - " z = inlier_outlier_mix([jnp.log(1.0-outlier), jnp.log(outlier)], (\n", - " (z_, sig, 0.0, zmax_), \n", - " (0.0, zmax_ + 1e-6))) @ \"measurement\"\n", + " inlier_outlier_mix = genjax.tfp_mixture(\n", + " genjax.tfp_categorical, [truncnormal, genjax.tfp_uniform]\n", + " )\n", + " z = (\n", + " inlier_outlier_mix(\n", + " [jnp.log(1.0 - outlier), jnp.log(outlier)],\n", + " ((z_, sig, 0.0, zmax_), (0.0, zmax_ + 1e-6)),\n", + " )\n", + " @ \"measurement\"\n", + " )\n", "\n", " z = jnp.clip(z, 0.0, zmax_)\n", "\n", - " return z * y/z_\n", + " return z * y / z_\n", "\n", - " \n", " @genjax.gen\n", - " def sensor_model(Y, sig, out): \n", + " def sensor_model(Y, sig, out):\n", " \"\"\"\n", - " Simplest sensor model that returns a vector of range measurements conditioned on \n", + " Simplest sensor model that returns a vector of range measurements conditioned on\n", " an image, noise level, and outlier probability.\n", " \"\"\"\n", - " \n", - " X = genjax.Map(_sensor_model, (0,None,None))(Y[...,:3].reshape(-1,3), sig, out) @ \"X\"\n", + "\n", + " X = (\n", + " genjax.Map(_sensor_model, (0, None, None))(\n", + " Y[..., :3].reshape(-1, 3), sig, out\n", + " )\n", + " @ \"X\"\n", + " )\n", " X = X.reshape(Y.shape)\n", "\n", " return X\n", @@ -185,18 +195,17 @@ "source": [ "model = make_simple_sensor_model(5.0)\n", "\n", - "Y= jnp.array([\n", - " [\n", - " [0,0,1],\n", - " [0,1,2],\n", - " ],\n", + "Y = jnp.array(\n", " [\n", - " [1,1,3],\n", - " [1,0,4]\n", + " [\n", + " [0, 0, 1],\n", + " [0, 1, 2],\n", + " ],\n", + " [[1, 1, 3], [1, 0, 4]],\n", " ]\n", - "])\n", + ")\n", "Y.shape\n", - "Y[...,2]" + "Y[..., 2]" ] }, { @@ -221,7 +230,7 @@ ], "source": [ "key = keysplit(key)\n", - "model(Y,0.1,0.1)(key)" + "model(Y, 0.1, 0.1)(key)" ] }, { @@ -237,43 +246,52 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def make_simple_step_sensor_model(far):\n", - " \"\"\"Returns an simple step function sensor model marginalized over outliers.\"\"\" \n", + " \"\"\"Returns an simple step function sensor model marginalized over outliers.\"\"\"\n", "\n", " @genjax.drop_arguments\n", " @genjax.gen\n", " def _sensor_model_pixel(y, sig, out):\n", - " \n", - "\n", " # Compute max range along ray ending at far plane\n", - " r_ = jnp.linalg.norm(y)\n", - " rmax = r_/y[2]*far \n", + " r_ = jnp.linalg.norm(y)\n", + " rmax = r_ / y[2] * far\n", "\n", " inlier_outlier_mix = genjax.tfp_mixture(\n", - " genjax.tfp_categorical, \n", - " [genjax.tfp_uniform, genjax.tfp_uniform])\n", + " genjax.tfp_categorical, [genjax.tfp_uniform, genjax.tfp_uniform]\n", + " )\n", "\n", " # The `1e-4` term helps with numerical issues from computing rmax\n", " # at least that's what I think\n", - " r = 
inlier_outlier_mix(\n", - " [jnp.log(1 - out), jnp.log(out)], \n", - " ((jnp.maximum(r_-sig, 0.0) , jnp.minimum(r_+sig, rmax)), (0.0, rmax + 1e-4))) @ \"measurement\"\n", + " r = (\n", + " inlier_outlier_mix(\n", + " [jnp.log(1 - out), jnp.log(out)],\n", + " (\n", + " (jnp.maximum(r_ - sig, 0.0), jnp.minimum(r_ + sig, rmax)),\n", + " (0.0, rmax + 1e-4),\n", + " ),\n", + " )\n", + " @ \"measurement\"\n", + " )\n", "\n", " r = jnp.clip(r, 0.0, rmax)\n", "\n", - " return r * y/r_\n", + " return r * y / r_\n", "\n", - " \n", " @genjax.gen\n", " def sensor_model(Y, sig, out):\n", " \"\"\"\n", - " Simplest sensor model that returns a vector of range measurements conditioned on \n", + " Simplest sensor model that returns a vector of range measurements conditioned on\n", " an image, noise level, and outlier probability.\n", " \"\"\"\n", - " \n", - " X = genjax.Map(_sensor_model_pixel, (0,None,None))(Y[...,:3].reshape(-1,3), sig, out) @ \"X\"\n", - " X = X.reshape(Y[...,:3].shape)\n", + "\n", + " X = (\n", + " genjax.Map(_sensor_model_pixel, (0, None, None))(\n", + " Y[..., :3].reshape(-1, 3), sig, out\n", + " )\n", + " @ \"X\"\n", + " )\n", + " X = X.reshape(Y[..., :3].shape)\n", "\n", " return X\n", "\n", @@ -311,24 +329,23 @@ "model = make_simple_step_sensor_model(zmax)\n", "\n", "\n", - "Y= jnp.array([\n", + "Y = jnp.array(\n", " [\n", - " [0,0,1],\n", - " [0,1,2],\n", - " ],\n", - " [\n", - " [1,1,3],\n", - " [1,0,4]\n", + " [\n", + " [0, 0, 1],\n", + " [0, 1, 2],\n", + " ],\n", + " [[1, 1, 3], [1, 0, 4]],\n", " ]\n", - "])\n", + ")\n", "\n", "key = keysplit(key)\n", - "X = model(Y, 0.1 , 0.2)(key)\n", + "X = model(Y, 0.1, 0.2)(key)\n", "\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(8,4))\n", - "axs[0].plot(X[...,2].ravel(), marker=\"s\")\n", - "axs[0].plot(Y[...,2].ravel(), marker=\"s\")" + "fig, axs = plt.subplots(1, 2, figsize=(8, 4))\n", + "axs[0].plot(X[..., 2].ravel(), marker=\"s\")\n", + "axs[0].plot(Y[..., 2].ravel(), marker=\"s\")" ] }, { @@ -344,22 +361,27 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from genjax._src.generative_functions.distributions.distribution import ExactDensity\n", "\n", + "\n", "def wrap_into_dist(score_func):\n", " \"\"\"\n", - " Takes a scoring function \n", + " Takes a scoring function\n", "\n", - " `score_func(observed, latent, ...)` \n", + " `score_func(observed, latent, ...)`\n", "\n", " and wraps it into a genjax distribution.\n", " \"\"\"\n", + "\n", " class WrappedScoreFunc(ExactDensity):\n", - " def sample(self, key, latent, *args): return latent\n", - " def logpdf(self, observed, latent, *args): return score_func(observed, latent, *args)\n", + " def sample(self, key, latent, *args):\n", + " return latent\n", + "\n", + " def logpdf(self, observed, latent, *args):\n", + " return score_func(observed, latent, *args)\n", "\n", - " return WrappedScoreFunc()\n" + " return WrappedScoreFunc()" ] }, { diff --git a/scripts/_mkl/notebooks/30 - Table Scene Model.ipynb b/scripts/_mkl/notebooks/30 - Table Scene Model.ipynb index 402e3ab1..c544faaa 100644 --- a/scripts/_mkl/notebooks/30 - Table Scene Model.ipynb +++ b/scripts/_mkl/notebooks/30 - Table Scene Model.ipynb @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp table_scene_model" + "# |default_exp table_scene_model" ] }, { @@ -34,7 +34,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import bayes3d as b3d\n", "import bayes3d.genjax\n", "import joblib\n", @@ -58,16 +58,16 @@ "metadata": {}, "outputs": [], 
"source": [ - "#|export\n", + "# |export\n", "from jax.scipy.spatial.transform import Rotation\n", "from scipy.stats import truncnorm as scipy_truncnormal\n", "\n", - "normal_logpdf = jax.scipy.stats.norm.logpdf\n", - "normal_pdf = jax.scipy.stats.norm.pdf\n", + "normal_logpdf = jax.scipy.stats.norm.logpdf\n", + "normal_pdf = jax.scipy.stats.norm.pdf\n", "truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf\n", - "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf\n", + "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf\n", "\n", - "inv = jnp.linalg.inv\n", + "inv = jnp.linalg.inv\n", "logaddexp = jnp.logaddexp\n", "logsumexp = jax.scipy.special.logsumexp\n", "\n", @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "from bayes3d._mkl.utils import keysplit\n", "from bayes3d._mkl.plotting import *" ] @@ -98,14 +98,16 @@ "metadata": {}, "outputs": [], "source": [ - "_scaling = 1e-3\n", - "model_dir = os.path.join(b3d.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "_scaling = 1e-3\n", + "model_dir = os.path.join(b3d.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", " mesh = trimesh.load(mesh_path)\n", - " mesh.vertices *= _scaling \n", + " mesh.vertices *= _scaling\n", " meshes.append(mesh)\n", "\n", "\n", @@ -113,7 +115,7 @@ "mesh_path = os.path.join(b3d.utils.get_assets_dir(), \"sample_objs/cube.obj\")\n", "mesh = trimesh.load(mesh_path)\n", "mesh.vertices *= 1e-9\n", - "meshes.append(mesh)\n" + "meshes.append(mesh)" ] }, { @@ -123,27 +125,34 @@ "outputs": [], "source": [ "# Set up the renderer and add the scene mesh\n", - "def make_render_function(meshes, w=100, h=100, fx=30, fy=30, offx=-0.5, offy=-0.5, far=20, near=0.01):\n", + "def make_render_function(\n", + " meshes, w=100, h=100, fx=30, fy=30, offx=-0.5, offy=-0.5, far=20, near=0.01\n", + "):\n", " \"\"\"\n", - " Create a render function from a list of meshes \n", + " Create a render function from a list of meshes\n", " (and camera intrinsics).\n", " \"\"\"\n", " intrinsics = b3d.Intrinsics(\n", - " width = w, height = h,\n", - " fx = fx, fy = fy,\n", - " cx = w/2 + offx, cy = h/2 + offy,\n", - " near = near, far = far\n", + " width=w,\n", + " height=h,\n", + " fx=fx,\n", + " fy=fy,\n", + " cx=w / 2 + offx,\n", + " cy=h / 2 + offy,\n", + " near=near,\n", + " far=far,\n", " )\n", - " \n", + "\n", " b3d.setup_renderer(intrinsics)\n", - " for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)\n", + " for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)\n", "\n", - " def render(cam:\"Camera Pose\", ps:\"Object Poses\", inds:\"Object indices\"):\n", + " def render(cam: \"Camera Pose\", ps: \"Object Poses\", inds: \"Object indices\"):\n", " \"\"\"\n", - " Returns image of shape `(h, w, 4)` where the first 3 channels encode \n", + " Returns image of shape `(h, w, 4)` where the first 3 channels encode\n", " xyz-coordinates and the last channel encodes semantic information.\n", " \"\"\"\n", - " return b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + " return b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", " return render" ] @@ -175,12 +184,12 @@ } ], "source": [ - "_far = 5.0\n", - "_shape = (200,200)\n", - "_f = 300\n", + "_far = 5.0\n", + "_shape = (200, 200)\n", + "_f = 
300\n", "\n", - "_intr = dict(w=_shape[1], h=_shape[0], fx=_f, fy=_f, near=1e-4, far=_far)\n", - "render = make_render_function(meshes, **_intr)\n", + "_intr = dict(w=_shape[1], h=_shape[0], fx=_f, fy=_f, near=1e-4, far=_far)\n", + "render = make_render_function(meshes, **_intr)\n", "\n", "help(render)" ] @@ -192,7 +201,7 @@ "outputs": [], "source": [ "def prep_im(Y, far=5.0, eps=1e-6):\n", - " im = np.where(Y[:,:,2]>= far - eps, jnp.inf, Y[:,:,2])\n", + " im = np.where(Y[:, :, 2] >= far - eps, jnp.inf, Y[:, :, 2])\n", " return im" ] }, @@ -209,7 +218,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def make_table_scene_model():\n", " \"\"\"\n", " Example:\n", @@ -221,13 +230,13 @@ "\n", " table = jnp.eye(4)\n", " cam = b3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, -.5, -.75]), \n", - " jnp.zeros(3), \n", + " jnp.array([0.0, -.5, -.75]),\n", + " jnp.zeros(3),\n", " jnp.array([0.0,-1.0,0.0]))\n", "\n", " args = (\n", - " jnp.arange(3), \n", - " jnp.arange(22), \n", + " jnp.arange(3),\n", + " jnp.arange(22),\n", " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", " b3d.RENDERER.model_box_dims\n", @@ -258,40 +267,45 @@ " \"\"\"\n", "\n", " @genjax.gen\n", - " def model(nums, \n", - " possible_object_indices, \n", - " pose_bounds, \n", - " contact_bounds, \n", - " all_box_dims):\n", - " \n", - " num_objects = len(nums) # this is a hack, otherwise genajx is complaining\n", - "\n", - " indices = jnp.array([], dtype=jnp.int32)\n", - " root_poses = jnp.zeros((0,4,4))\n", - " contact_params = jnp.zeros((0,3))\n", - " faces_parents = jnp.array([], dtype=jnp.int32)\n", - " faces_child = jnp.array([], dtype=jnp.int32)\n", - " parents = jnp.array([], dtype=jnp.int32)\n", - "\n", - " for i in range(num_objects):\n", + " def model(nums, possible_object_indices, pose_bounds, contact_bounds, all_box_dims):\n", + " num_objects = len(nums) # this is a hack, otherwise genajx is complaining\n", "\n", - " index = uniform_discrete(possible_object_indices) @ f\"id_{i}\"\n", - " pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ f\"root_pose_{i}\"\n", - " params = contact_params_uniform(contact_bounds[0], contact_bounds[1]) @ f\"contact_params_{i}\"\n", + " indices = jnp.array([], dtype=jnp.int32)\n", + " root_poses = jnp.zeros((0, 4, 4))\n", + " contact_params = jnp.zeros((0, 3))\n", + " faces_parents = jnp.array([], dtype=jnp.int32)\n", + " faces_child = jnp.array([], dtype=jnp.int32)\n", + " parents = jnp.array([], dtype=jnp.int32)\n", "\n", - " parent_obj = uniform_discrete(jnp.arange(-1, num_objects - 1)) @ f\"parent_{i}\"\n", - " parent_face = uniform_discrete(jnp.arange(0,6)) @ f\"face_parent_{i}\"\n", - " child_face = uniform_discrete(jnp.arange(0,6)) @ f\"face_child_{i}\"\n", - "\n", - " indices = jnp.concatenate([indices, jnp.array([index])])\n", - " root_poses = jnp.concatenate([root_poses, pose.reshape(1,4,4)])\n", - " contact_params = jnp.concatenate([contact_params, params.reshape(1,-1)])\n", - " parents = jnp.concatenate([parents, jnp.array([parent_obj])])\n", - " faces_parents = jnp.concatenate([faces_parents, jnp.array([parent_face])])\n", - " faces_child = jnp.concatenate([faces_child, jnp.array([child_face])])\n", - " \n", - "\n", - " scene = (root_poses, all_box_dims[indices], parents, contact_params, faces_parents, faces_child)\n", + " for i in range(num_objects):\n", + " index = uniform_discrete(possible_object_indices) @ f\"id_{i}\"\n", + " pose = 
uniform_pose(pose_bounds[0], pose_bounds[1]) @ f\"root_pose_{i}\"\n", + " params = (\n", + " contact_params_uniform(contact_bounds[0], contact_bounds[1])\n", + " @ f\"contact_params_{i}\"\n", + " )\n", + "\n", + " parent_obj = (\n", + " uniform_discrete(jnp.arange(-1, num_objects - 1)) @ f\"parent_{i}\"\n", + " )\n", + " parent_face = uniform_discrete(jnp.arange(0, 6)) @ f\"face_parent_{i}\"\n", + " child_face = uniform_discrete(jnp.arange(0, 6)) @ f\"face_child_{i}\"\n", + "\n", + " indices = jnp.concatenate([indices, jnp.array([index])])\n", + " root_poses = jnp.concatenate([root_poses, pose.reshape(1, 4, 4)])\n", + " contact_params = jnp.concatenate([contact_params, params.reshape(1, -1)])\n", + " parents = jnp.concatenate([parents, jnp.array([parent_obj])])\n", + " faces_parents = jnp.concatenate([faces_parents, jnp.array([parent_face])])\n", + " faces_child = jnp.concatenate([faces_child, jnp.array([child_face])])\n", + "\n", + " scene = (\n", + " root_poses,\n", + " all_box_dims[indices],\n", + " parents,\n", + " contact_params,\n", + " faces_parents,\n", + " faces_child,\n", + " )\n", " poses = b.scene_graph.poses_from_scene_graph(*scene)\n", "\n", " camera_pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ f\"camera_pose\"\n", @@ -353,36 +367,40 @@ "source": [ "key = keysplit(key)\n", "\n", - "cam_x = jnp.array([0.0, -.5, -.75])\n", - "cam = b3d.transform_from_pos_target_up(cam_x, jnp.zeros(3), jnp.array([0.0,-1.0,0.0]))\n", + "cam_x = jnp.array([0.0, -0.5, -0.75])\n", + "cam = b3d.transform_from_pos_target_up(cam_x, jnp.zeros(3), jnp.array([0.0, -1.0, 0.0]))\n", "\n", "table = jnp.eye(4)\n", "\n", "args = (\n", " jnp.arange(3),\n", " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b3d.RENDERER.model_box_dims\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b3d.RENDERER.model_box_dims,\n", ")\n", "\n", - "ch = genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"parent_2\": 0,\n", - " \"camera_pose\": cam,\n", - " \"root_pose_0\": table,\n", - " \"id_0\": jnp.int32(21), # Atomic Table\n", - " \"id_1\": jnp.int32(13), # Mug\n", - " \"id_2\": jnp.int32(2), # Box\n", - " \"face_parent_1\": 1, # That's the top face of the table\n", - " \"face_parent_2\": 1, # ...\n", - " \"face_child_1\": 3, # That's a bottom face of the mug\n", - " \"face_child_2\": 3,\n", - "})\n", - "\n", - "\n", - "w, tr = model.importance(key, ch , args)\n", + "ch = genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"parent_2\": 0,\n", + " \"camera_pose\": cam,\n", + " \"root_pose_0\": table,\n", + " \"id_0\": jnp.int32(21), # Atomic Table\n", + " \"id_1\": jnp.int32(13), # Mug\n", + " \"id_2\": jnp.int32(2), # Box\n", + " \"face_parent_1\": 1, # That's the top face of the table\n", + " \"face_parent_2\": 1, # ...\n", + " \"face_child_1\": 3, # That's a bottom face of the mug\n", + " \"face_child_2\": 3,\n", + " }\n", + ")\n", + "\n", + "\n", + "w, tr = model.importance(key, ch, args)\n", "cam, ps, inds = tr.retval\n", "X = render(cam, ps, inds)\n", "\n", @@ -398,13 +416,16 @@ "outputs": [], "source": [ "fname = f\"table_scene_3\"\n", - "jnp.save(f\"data/likelihood_test/{fname}\", dict(\n", - " key = key,\n", - " args = tr.args,\n", - " choices = tr.strip(), \n", - " intrinsics = _intr,\n", 
- " rendered = X,\n", - "))" + "jnp.save(\n", + " f\"data/likelihood_test/{fname}\",\n", + " dict(\n", + " key=key,\n", + " args=tr.args,\n", + " choices=tr.strip(),\n", + " intrinsics=_intr,\n", + " rendered=X,\n", + " ),\n", + ")" ] }, { @@ -415,7 +436,7 @@ "source": [ "fname = f\"table_scene_3\"\n", "arr = jnp.load(f\"data/likelihood_test/{fname}.npy\", allow_pickle=True)\n", - "ch = arr.item()[\"choices\"]\n" + "ch = arr.item()[\"choices\"]" ] }, { @@ -445,7 +466,7 @@ } ], "source": [ - "_, tr_ = model.importance(key, ch , args)\n", + "_, tr_ = model.importance(key, ch, args)\n", "cam, ps, inds = tr_.retval\n", "X = render(cam, ps, inds)\n", "\n", @@ -468,34 +489,35 @@ "source": [ "from bayes3d._mkl.pose import pack_pose\n", "\n", + "\n", "def generic_viewpoint(key, cam, n, sig_x, sig_hd):\n", " \"\"\"Generates generix camera poses by varying its xy-coordinates and angle (in the xy-plane).\"\"\"\n", - " \n", + "\n", " # TODO: Make a version that varies rot and pitch and potentially roll.\n", - " \n", - " _, keys = keysplit(key,1,2)\n", + "\n", + " _, keys = keysplit(key, 1, 2)\n", "\n", " # Generic position\n", - " xs = sig_x*jax.random.normal(keys[1], (n,3))\n", - " xs = xs.at[0,:].set(0.0)\n", - " xs = xs.at[:,2].set(0.0)\n", + " xs = sig_x * jax.random.normal(keys[1], (n, 3))\n", + " xs = xs.at[0, :].set(0.0)\n", + " xs = xs.at[:, 2].set(0.0)\n", "\n", " # Generic rotation\n", - " hds = sig_hd*jax.random.normal(keys[0], (n,))\n", + " hds = sig_hd * jax.random.normal(keys[0], (n,))\n", " hds = hds.at[0].set(0.0)\n", - " rs = vmap(Rotation.from_euler, (None,0))(\"y\", hds)\n", + " rs = vmap(Rotation.from_euler, (None, 0))(\"y\", hds)\n", " rs = Rotation.as_matrix(rs)\n", - " \n", + "\n", " # Generic camera poses\n", " ps = vmap(pack_pose)(xs, rs)\n", - " ps = cam@ps\n", + " ps = cam @ ps\n", "\n", " # Generic weights\n", " logps_hd = normal_logpdf(hds, loc=0.0, scale=sig_hd)\n", - " logps_x = normal_logpdf( xs, loc=0.0, scale=sig_x).sum(-1)\n", - " logps = logps_hd + logps_x\n", + " logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1)\n", + " logps = logps_hd + logps_x\n", "\n", - " return ps, logps\n" + " return ps, logps" ] }, { @@ -505,33 +527,31 @@ "outputs": [], "source": [ "def generic_contact(key, p0, n, sig_x, sig_hd):\n", - "\n", - " _, keys = keysplit(key,1,2)\n", + " _, keys = keysplit(key, 1, 2)\n", "\n", " # Generic contact-pose vector\n", - " xs = sig_x*jax.random.normal(keys[1], (n,3))\n", - " xs = xs.at[:,2].set(0.0)\n", - " xs = xs.at[0,:].set(0.0)\n", + " xs = sig_x * jax.random.normal(keys[1], (n, 3))\n", + " xs = xs.at[:, 2].set(0.0)\n", + " xs = xs.at[0, :].set(0.0)\n", "\n", - " hds = sig_hd*jax.random.normal(keys[0], (n,1))\n", - " hds = hds.at[0,:].set(0.0)\n", - " rs = vmap(Rotation.from_euler, (None,0))(\"z\", hds)\n", + " hds = sig_hd * jax.random.normal(keys[0], (n, 1))\n", + " hds = hds.at[0, :].set(0.0)\n", + " rs = vmap(Rotation.from_euler, (None, 0))(\"z\", hds)\n", " rs = Rotation.as_matrix(rs)\n", - " \n", + "\n", " # Generic camera poses\n", " ps = vmap(pack_pose)(xs, rs)\n", " # vs = jnp.concatenate([xs, hds], axis=1)\n", "\n", " # Generic weights\n", - " logps_hd = normal_logpdf(hds[:,0], loc=0.0, scale=sig_hd)\n", - " logps_x = normal_logpdf (xs, loc=0.0, scale=sig_x).sum(-1)\n", - " logps = logps_hd + logps_x\n", + " logps_hd = normal_logpdf(hds[:, 0], loc=0.0, scale=sig_hd)\n", + " logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1)\n", + " logps = logps_hd + logps_x\n", "\n", " # Generic object pose\n", - " generic_ps = p0@ps\n", + " 
generic_ps = p0 @ ps\n",
 "\n",
- " return generic_ps, logps\n",
- "\n"
+ " return generic_ps, logps"
 ]
},
{
@@ -557,11 +577,11 @@
 "ws -= logsumexp(ws)\n",
 "\n",
 "\n",
- "Ys = vmap(render, (None,0,None))(cam, generic_ps[:,None], inds[1][None])\n",
+ "Ys = vmap(render, (None, 0, None))(cam, generic_ps[:, None], inds[1][None])\n",
 "\n",
 "\n",
 "# ==============================\n",
- "plt.figure(figsize=(5,5))\n",
+ "plt.figure(figsize=(5, 5))\n",
 "plt.title(\"Generic Viewpoint\")\n",
 "for Y in Ys[:]:\n",
 " plt.imshow(prep_im(Y), alpha=0.1)"
 ]
},
{
@@ -575,11 +595,12 @@
 "source": [
 "def get_linear_grid(shape, bounds, flat=False):\n",
 " \"\"\"Create a (linear) grid of a given shape and bounds.\"\"\"\n",
- " \n",
- " linspaces = [jnp.linspace(*b, d) for b,d in zip(bounds, shape)]\n",
- " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing='ij'), axis=-1)\n",
- " if flat: vs = vs.reshape(-1,len(shape))\n",
- " \n",
+ "\n",
+ " linspaces = [jnp.linspace(*b, d) for b, d in zip(bounds, shape)]\n",
+ " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing=\"ij\"), axis=-1)\n",
+ " if flat:\n",
+ " vs = vs.reshape(-1, len(shape))\n",
+ "\n",
 " return vs"
 ]
},
@@ -600,15 +621,15 @@
 }
 ],
 "source": [
- "vs = get_linear_grid((5, 5, 5), ((-2,2), (-2,2), (0,jnp.pi/3)), flat=True)\n",
+ "vs = get_linear_grid((5, 5, 5), ((-2, 2), (-2, 2), (0, jnp.pi / 3)), flat=True)\n",
 "sc = jnp.arange(len(vs))\n",
 "\n",
 "# =======================================\n",
- "fig, ax = plt.subplots(1,1,figsize=(3,3))\n",
- "zoom_in(vs[:,:2], 2, ax=ax)\n",
+ "fig, ax = plt.subplots(1, 1, figsize=(3, 3))\n",
+ "zoom_in(vs[:, :2], 2, ax=ax)\n",
 "plot_poses(vs, sc, linewidth=1, ax=ax)\n",
- "ax.spines['top'].set_visible(False)\n",
- "ax.spines['right'].set_visible(False)"
+ "ax.spines[\"top\"].set_visible(False)\n",
+ "ax.spines[\"right\"].set_visible(False)"
 ]
},
{
@@ -618,7 +639,7 @@
 "outputs": [],
 "source": [
 "def _contact_from_grid(v, p0=jnp.eye(4), sig_x=1.0, sig_hd=1.0):\n",
- " x = jnp.array([*v[:2],0.0])\n",
+ " x = jnp.array([*v[:2], 0.0])\n",
 " hd = v[2]\n",
 "\n",
 " r = Rotation.from_euler(\"z\", hd)\n",
@@ -626,12 +647,13 @@
 " p = pack_pose(x, r)\n",
 "\n",
 " logp_hd = normal_logpdf(hd, loc=0.0, scale=sig_hd)\n",
- " logp_x = normal_logpdf (x, loc=0.0, scale=sig_x).sum(-1)\n",
- " logp = logp_hd + logp_x\n",
+ " logp_x = normal_logpdf(x, loc=0.0, scale=sig_x).sum(-1)\n",
+ " logp = logp_hd + logp_x\n",
+ "\n",
+ " return p0 @ p, logp\n",
 "\n",
- " return p0@p, logp\n",
 "\n",
- "contact_from_grid = vmap(_contact_from_grid, (0,None,None,None))"
+ "contact_from_grid = vmap(_contact_from_grid, (0, None, None, None))"
 ]
},
{
@@ -640,14 +662,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "dx = 0.1\n",
- "dhd = jnp.pi/4\n",
 "sig_hd = dhd\n",
- "sig_x = dx\n",
+ "dx = 0.1\n",
+ "dhd = jnp.pi / 4\n",
+ "sig_hd = dhd\n",
+ "sig_x = dx\n",
 "\n",
 "key = keysplit(key)\n",
- "v0 = 0.01*jax.random.normal(key, (3,))\n",
- "vs = get_linear_grid((10, 10, 10), ((-dx,dx), (-dx,dx), (-dhd,dhd)), flat=True)\n",
+ "v0 = 0.01 * jax.random.normal(key, (3,))\n",
+ "vs = get_linear_grid((10, 10, 10), ((-dx, dx), (-dx, dx), (-dhd, dhd)), flat=True)\n",
 "vs += v0\n",
 "\n",
 "\n",
@@ -671,14 +693,14 @@
 }
 ],
 "source": [
- "sc = logps\n",
+ "sc = logps\n",
 "\n",
 "# =======================================\n",
- "fig, ax = plt.subplots(1,1,figsize=(3,3))\n",
- "zoom_in(vs[:,:2], 0.1, ax=ax)\n",
+ "fig, ax = plt.subplots(1, 1, figsize=(3, 3))\n",
+ "zoom_in(vs[:, :2], 0.1, ax=ax)\n",
 "plot_poses(vs, sc, r=0.025, linewidth=1, ax=ax)\n",
- "ax.spines['top'].set_visible(False)\n",
- 
"ax.spines['right'].set_visible(False)" + "ax.spines[\"top\"].set_visible(False)\n", + "ax.spines[\"right\"].set_visible(False)" ] }, { @@ -687,7 +709,7 @@ "metadata": {}, "outputs": [], "source": [ - "render_generic = jit(vmap(render, (None,0,None)))" + "render_generic = jit(vmap(render, (None, 0, None)))" ] }, { @@ -707,13 +729,13 @@ } ], "source": [ - "Ys = vmap(render, (None,0,None))(cam, generic_ps[:,None], inds[1][None])\n", + "Ys = vmap(render, (None, 0, None))(cam, generic_ps[:, None], inds[1][None])\n", "\n", "# ==============================\n", - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "plt.title(\"Generic Viewpoint\")\n", "for Y in Ys[:]:\n", - " plt.imshow(prep_im(Y), alpha=.1)" + " plt.imshow(prep_im(Y), alpha=0.1)" ] }, { diff --git a/scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb b/scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb index 80b3268e..51000ebc 100644 --- a/scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb +++ b/scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|default_exp generic" + "# |default_exp generic" ] }, { @@ -15,7 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "import jax\n", "from bayes3d._mkl.utils import keysplit\n", "from bayes3d._mkl.pose import pack_pose\n", @@ -24,12 +24,12 @@ "from jax.scipy.spatial.transform import Rotation\n", "from scipy.stats import truncnorm as scipy_truncnormal\n", "\n", - "normal_logpdf = jax.scipy.stats.norm.logpdf\n", - "normal_pdf = jax.scipy.stats.norm.pdf\n", + "normal_logpdf = jax.scipy.stats.norm.logpdf\n", + "normal_pdf = jax.scipy.stats.norm.pdf\n", "truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf\n", - "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf\n", + "truncnorm_pdf = jax.scipy.stats.truncnorm.pdf\n", "\n", - "inv = jnp.linalg.inv\n", + "inv = jnp.linalg.inv\n", "logaddexp = jnp.logaddexp\n", "logsumexp = jax.scipy.special.logsumexp" ] @@ -40,35 +40,35 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def generic_viewpoint(key, cam, n, sig_x, sig_hd):\n", " \"\"\"Generates generix camera poses by varying its xy-coordinates and angle (in the xy-plane).\"\"\"\n", - " \n", + "\n", " # TODO: Make a version that varies rot and pitch and potentially roll.\n", - " \n", - " _, keys = keysplit(key,1,2)\n", + "\n", + " _, keys = keysplit(key, 1, 2)\n", "\n", " # Generic position\n", - " xs = sig_x*jax.random.normal(keys[1], (n,3))\n", - " xs = xs.at[0,:].set(0.0)\n", - " xs = xs.at[:,2].set(0.0)\n", + " xs = sig_x * jax.random.normal(keys[1], (n, 3))\n", + " xs = xs.at[0, :].set(0.0)\n", + " xs = xs.at[:, 2].set(0.0)\n", "\n", " # Generic rotation\n", - " hds = sig_hd*jax.random.normal(keys[0], (n,))\n", + " hds = sig_hd * jax.random.normal(keys[0], (n,))\n", " hds = hds.at[0].set(0.0)\n", - " rs = vmap(Rotation.from_euler, (None,0))(\"y\", hds)\n", + " rs = vmap(Rotation.from_euler, (None, 0))(\"y\", hds)\n", " rs = Rotation.as_matrix(rs)\n", - " \n", + "\n", " # Generic camera poses\n", " ps = vmap(pack_pose)(xs, rs)\n", - " ps = cam@ps\n", + " ps = cam @ ps\n", "\n", " # Generic weights\n", " logps_hd = normal_logpdf(hds, loc=0.0, scale=sig_hd)\n", - " logps_x = normal_logpdf( xs, loc=0.0, scale=sig_x).sum(-1)\n", - " logps = logps_hd + logps_x\n", + " logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1)\n", + " logps = logps_hd + logps_x\n", "\n", - " return ps, logps\n" + " return ps, logps" ] }, { @@ -77,35 +77,33 @@ "metadata": {}, 
"outputs": [], "source": [ - "#|export\n", + "# |export\n", "def generic_contact(key, p0, n, sig_x, sig_hd):\n", - "\n", - " _, keys = keysplit(key,1,2)\n", + " _, keys = keysplit(key, 1, 2)\n", "\n", " # Generic contact-pose vector\n", - " xs = sig_x*jax.random.normal(keys[1], (n,3))\n", - " xs = xs.at[:,2].set(0.0)\n", - " xs = xs.at[0,:].set(0.0)\n", + " xs = sig_x * jax.random.normal(keys[1], (n, 3))\n", + " xs = xs.at[:, 2].set(0.0)\n", + " xs = xs.at[0, :].set(0.0)\n", "\n", - " hds = sig_hd*jax.random.normal(keys[0], (n,1))\n", - " hds = hds.at[0,:].set(0.0)\n", - " rs = vmap(Rotation.from_euler, (None,0))(\"z\", hds)\n", + " hds = sig_hd * jax.random.normal(keys[0], (n, 1))\n", + " hds = hds.at[0, :].set(0.0)\n", + " rs = vmap(Rotation.from_euler, (None, 0))(\"z\", hds)\n", " rs = Rotation.as_matrix(rs)\n", - " \n", + "\n", " # Generic camera poses\n", " ps = vmap(pack_pose)(xs, rs)\n", " # vs = jnp.concatenate([xs, hds], axis=1)\n", "\n", " # Generic weights\n", - " logps_hd = normal_logpdf(hds[:,0], loc=0.0, scale=sig_hd)\n", - " logps_x = normal_logpdf (xs, loc=0.0, scale=sig_x).sum(-1)\n", - " logps = logps_hd + logps_x\n", + " logps_hd = normal_logpdf(hds[:, 0], loc=0.0, scale=sig_hd)\n", + " logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1)\n", + " logps = logps_hd + logps_x\n", "\n", " # Generic object pose\n", - " generic_ps = p0@ps\n", + " generic_ps = p0 @ ps\n", "\n", - " return generic_ps, logps\n", - "\n" + " return generic_ps, logps" ] }, { @@ -114,9 +112,9 @@ "metadata": {}, "outputs": [], "source": [ - "#|export\n", + "# |export\n", "def _contact_from_grid(v, p0=jnp.eye(4), sig_x=1.0, sig_hd=1.0):\n", - " x = jnp.array([*v[:2],0.0])\n", + " x = jnp.array([*v[:2], 0.0])\n", " hd = v[2]\n", "\n", " r = Rotation.from_euler(\"z\", hd)\n", @@ -124,12 +122,13 @@ " p = pack_pose(x, r)\n", "\n", " logp_hd = normal_logpdf(hd, loc=0.0, scale=sig_hd)\n", - " logp_x = normal_logpdf (x, loc=0.0, scale=sig_x).sum(-1)\n", - " logp = logp_hd + logp_x\n", + " logp_x = normal_logpdf(x, loc=0.0, scale=sig_x).sum(-1)\n", + " logp = logp_hd + logp_x\n", + "\n", + " return p0 @ p, logp\n", "\n", - " return p0@p, logp\n", "\n", - "contact_from_grid = vmap(_contact_from_grid, (0,None,None,None))" + "contact_from_grid = vmap(_contact_from_grid, (0, None, None, None))" ] }, { diff --git a/scripts/_mkl/notebooks/32 - Patch model.ipynb b/scripts/_mkl/notebooks/32 - Patch model.ipynb index 2c853e5d..2718b194 100644 --- a/scripts/_mkl/notebooks/32 - Patch model.ipynb +++ b/scripts/_mkl/notebooks/32 - Patch model.ipynb @@ -60,14 +60,16 @@ "metadata": {}, "outputs": [], "source": [ - "_scaling = 1e-3\n", - "model_dir = os.path.join(b3d.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "_scaling = 1e-3\n", + "model_dir = os.path.join(b3d.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", " mesh = trimesh.load(mesh_path)\n", - " mesh.vertices *= _scaling \n", + " mesh.vertices *= _scaling\n", " meshes.append(mesh)\n", "\n", "\n", @@ -75,7 +77,7 @@ "mesh_path = os.path.join(b3d.utils.get_assets_dir(), \"sample_objs/cube.obj\")\n", "mesh = trimesh.load(mesh_path)\n", "mesh.vertices *= 1e-9\n", - "meshes.append(mesh)\n" + "meshes.append(mesh)" ] }, { @@ -103,13 +105,11 @@ "h = 200\n", "f = 300\n", 
"intr = b3d.Intrinsics(\n", - " width = w, height = h,\n", - " fx = f, fy = f,\n", - " cx = w/2 - 0.5, cy = h/2 - 0.5,\n", - " near = 1e-4, far = 5.0\n", + " width=w, height=h, fx=f, fy=f, cx=w / 2 - 0.5, cy=h / 2 - 0.5, near=1e-4, far=5.0\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -119,7 +119,7 @@ "outputs": [], "source": [ "def prep_im(Y, far=5.0, eps=1e-6):\n", - " im = jnp.where(Y[:,:,2]>= far - eps, jnp.inf, Y[:,:,2])\n", + " im = jnp.where(Y[:, :, 2] >= far - eps, jnp.inf, Y[:, :, 2])\n", " return im" ] }, @@ -142,8 +142,8 @@ "data = data.item()\n", "\n", "scene_ch = data[\"choices\"]\n", - "cam = data[\"choices\"][\"camera_pose\"]\n", - "table = data[\"choices\"][\"root_pose_0\"]\n", + "cam = data[\"choices\"][\"camera_pose\"]\n", + "table = data[\"choices\"][\"root_pose_0\"]\n", "scene_model_args = data[\"args\"]" ] }, @@ -180,7 +180,7 @@ "w, tr = scene_model.importance(key, scene_ch, scene_model_args)\n", "cam, ps, inds = tr.retval\n", "\n", - "X = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + "X = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X))" @@ -195,7 +195,7 @@ "def get_patch(X, c, w, h):\n", " i = int(c[1])\n", " j = int(c[0])\n", - " patch = X[ i-h : i+h+1 , j-w : j+w+1 ]\n", + " patch = X[i - h : i + h + 1, j - w : j + w + 1]\n", " return patch" ] }, @@ -226,25 +226,25 @@ } ], "source": [ - "c = jnp.array([90,70])\n", - "att_w = 5\n", - "att_h = 5\n", + "c = jnp.array([90, 70])\n", + "att_w = 5\n", + "att_h = 5\n", "\n", "patch = get_patch(X, c, att_w, att_h)\n", "\n", "# =====================\n", - "fig, axs = plt.subplots(1, 2, figsize=(20,10))\n", - "axs[0].set_xlim(c[0]-att_w-20, c[0]+att_w+20)\n", - "axs[0].set_ylim(c[1]+att_h+20, c[1]-att_h-20)\n", + "fig, axs = plt.subplots(1, 2, figsize=(20, 10))\n", + "axs[0].set_xlim(c[0] - att_w - 20, c[0] + att_w + 20)\n", + "axs[0].set_ylim(c[1] + att_h + 20, c[1] - att_h - 20)\n", "axs[0].imshow(prep_im(X))\n", "axs[0].scatter(*c, c=\"r\", s=20, marker=\"x\")\n", - "axs[0].scatter(c[0]-att_w, c[1]-att_h, c=\"r\", s=10)\n", - "axs[0].scatter(c[0]-att_w, c[1]+att_h, c=\"r\", s=10)\n", - "axs[0].scatter(c[0]+att_w, c[1]-att_h, c=\"r\", s=10)\n", - "axs[0].scatter(c[0]+att_w, c[1]+att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] - att_w, c[1] - att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] - att_w, c[1] + att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] + att_w, c[1] - att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] + att_w, c[1] + att_h, c=\"r\", s=10)\n", "\n", "axs[1].imshow(prep_im(patch))\n", - "axs[1].scatter(att_w,att_h, c=\"r\", s=20, marker=\"x\")\n" + "axs[1].scatter(att_w, att_h, c=\"r\", s=20, marker=\"x\")" ] }, { @@ -279,7 +279,7 @@ "\n", "sensor_model = make_simple_sensor_model(5.0)\n", "key = keysplit(key)\n", - "X_ = sensor_model(patch[...,:3], 0.001,0.0)(key)\n", + "X_ = sensor_model(patch[..., :3], 0.001, 0.0)(key)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X_, far=4.5))" @@ -302,21 +302,18 @@ "\n", "\n", "def make_patch_model(scene_model_args):\n", - "\n", - " scene_model = make_table_scene_model()\n", + " scene_model = make_table_scene_model()\n", " sensor_model = make_simple_sensor_model(5.0)\n", "\n", - "\n", " @genjax.gen\n", " def patch_model(sig, out):\n", - "\n", " cam, ps, inds = scene_model(*scene_model_args) @ \"scene\"\n", - " Y = b3d.RENDERER.render(inv(cam) @ ps , 
inds)\n", - " X = sensor_model.inline(Y[...,:3], sig, out)\n", + " Y = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", + " X = sensor_model.inline(Y[..., :3], sig, out)\n", "\n", - " return X,Y\n", + " return X, Y\n", "\n", - " return patch_model\n" + " return patch_model" ] }, { @@ -341,19 +338,23 @@ ], "source": [ "w_orig = h_orig = 200\n", - "w_att = h_att = 2\n", - "fx = fy = 300\n", - "center = jnp.array([90., 70.])\n", + "w_att = h_att = 2\n", + "fx = fy = 300\n", + "center = jnp.array([90.0, 70.0])\n", "\n", "intr = b3d.Intrinsics(\n", - " width = 2*w_att+1, height = 2*h_att+1,\n", - " fx = fx, fy = fy,\n", - " cx = w_orig/2 - 0.5 + w_att - center[0], cy = h_orig/2 - 0.5 + h_att - center[1],\n", - " near = 1e-4, \n", - " far = 5.0\n", + " width=2 * w_att + 1,\n", + " height=2 * h_att + 1,\n", + " fx=fx,\n", + " fy=fy,\n", + " cx=w_orig / 2 - 0.5 + w_att - center[0],\n", + " cy=h_orig / 2 - 0.5 + h_att - center[1],\n", + " near=1e-4,\n", + " far=5.0,\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -365,9 +366,11 @@ "scene_model_args = (\n", " jnp.arange(2),\n", " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b3d.RENDERER.model_box_dims\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b3d.RENDERER.model_box_dims,\n", ")\n", "\n", "model = make_patch_model(scene_model_args)" @@ -400,26 +403,30 @@ } ], "source": [ - "ch = genjax.choice_map({\n", - " \"scene\": genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"camera_pose\": cam,\n", - " \"root_pose_0\": table,\n", - " \"id_0\": jnp.int32(21), # Atomic Table\n", - " \"id_1\": jnp.int32(13), # Mug\n", - " \"face_parent_1\": 1, # That's the top face of the table\n", - " \"face_child_1\": 3, # That's a bottom face of the mug\n", - " \"contact_params_1\": scene_ch[\"contact_params_1\"],\n", - " })\n", - "})\n", + "ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"camera_pose\": cam,\n", + " \"root_pose_0\": table,\n", + " \"id_0\": jnp.int32(21), # Atomic Table\n", + " \"id_1\": jnp.int32(13), # Mug\n", + " \"face_parent_1\": 1, # That's the top face of the table\n", + " \"face_child_1\": 3, # That's a bottom face of the mug\n", + " \"contact_params_1\": scene_ch[\"contact_params_1\"],\n", + " }\n", + " )\n", + " }\n", + ")\n", "\n", "key = keysplit(key)\n", - "_, tr0 = model.importance(key, ch, (1e-9,0.0))\n", + "_, tr0 = model.importance(key, ch, (1e-9, 0.0))\n", "\n", "P, _ = tr0.retval\n", "# ========================\n", - "plt.figure(figsize=(1,1))\n", + "plt.figure(figsize=(1, 1))\n", "plt.imshow(prep_im(P, far=4.5))" ] }, @@ -434,11 +441,12 @@ "\n", "def get_linear_grid(shape, bounds, flat=False):\n", " \"\"\"Create a (linear) grid of a given shape and bounds.\"\"\"\n", - " \n", - " linspaces = [jnp.linspace(*b, d) for b,d in zip(bounds, shape)]\n", - " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing='ij'), axis=-1)\n", - " if flat: vs = vs.reshape(-1,len(shape))\n", - " \n", + "\n", + " linspaces = [jnp.linspace(*b, d) for b, d in zip(bounds, shape)]\n", + " vs = jnp.stack(jnp.meshgrid(*linspaces, 
indexing=\"ij\"), axis=-1)\n", + " if flat:\n", + " vs = vs.reshape(-1, len(shape))\n", + "\n", " return vs" ] }, @@ -449,11 +457,15 @@ "outputs": [], "source": [ "def update(key, tr0, v):\n", - " ch = genjax.choice_map({ \"scene\": \n", - " genjax.choice_map({\n", - " \"contact_params_1\": v,\n", - " })\n", - " })\n", + " ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"contact_params_1\": v,\n", + " }\n", + " )\n", + " }\n", + " )\n", " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", " return tr1" @@ -466,18 +478,22 @@ "outputs": [], "source": [ "def _eval_contact(key, tr0, v):\n", - " ch = genjax.choice_map({ \"scene\": \n", - " genjax.choice_map({\n", - " \"contact_params_1\": v,\n", - " })\n", - " })\n", + " ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"contact_params_1\": v,\n", + " }\n", + " )\n", + " }\n", + " )\n", " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", "\n", " return tr1.get_score()\n", "\n", "\n", - "eval_contact = jit(vmap(_eval_contact, (None,None,0)))\n" + "eval_contact = jit(vmap(_eval_contact, (None, None, 0)))" ] }, { @@ -487,7 +503,7 @@ "outputs": [], "source": [ "key = keysplit(key)\n", - "_, tr0 = model.importance(key, tr0.strip(), (0.001,0.1))" + "_, tr0 = model.importance(key, tr0.strip(), (0.001, 0.1))" ] }, { @@ -526,11 +542,11 @@ } ], "source": [ - "dx = 0.05\n", + "dx = 0.05\n", "dhd = jnp.pi\n", "v0 = scene_ch[\"contact_params_1\"]\n", - "vs = get_linear_grid((65, 65, 65), ((-dx,dx), (-dx,dx), (-dhd,dhd)), flat=True)\n", - "vs += jnp.array([-0.05,0.1, 0.0])\n", + "vs = get_linear_grid((65, 65, 65), ((-dx, dx), (-dx, dx), (-dhd, dhd)), flat=True)\n", + "vs += jnp.array([-0.05, 0.1, 0.0])\n", "\n", "print(vs.shape)\n", "print((b3d.RENDERER.intrinsics.height, b3d.RENDERER.intrinsics.width))\n", @@ -541,12 +557,12 @@ "\n", "r = 0.003\n", "# =======================================\n", - "fig, ax = plt.subplots(1,1,figsize=(9,9))\n", - "ax.spines['top'].set_visible(False)\n", - "ax.spines['right'].set_visible(False)\n", - "zoom_in(vs[:,:2], 0.02, ax=ax)\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 9))\n", + "ax.spines[\"top\"].set_visible(False)\n", + "ax.spines[\"right\"].set_visible(False)\n", + "zoom_in(vs[:, :2], 0.02, ax=ax)\n", "plot_poses(vs, sc, r=r, linewidth=2, ax=ax, q=0.98)\n", - "plot_pose(scene_ch[\"contact_params_1\"], r=r/2, c=\"red\")\n", + "plot_pose(scene_ch[\"contact_params_1\"], r=r / 2, c=\"red\")\n", "plt.scatter(*v0[:2], c=\"r\", s=10)" ] }, @@ -579,7 +595,7 @@ "source": [ "order = jnp.argsort(sc)\n", "# =================\n", - "plt.figure(figsize=(6,1))\n", + "plt.figure(figsize=(6, 1))\n", "plt.plot(sc[order])" ] }, @@ -616,7 +632,7 @@ "\n", "X, Y = tr1.retval\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(10,5))\n", + "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "axs[0].imshow(prep_im(X, far=4.5))\n", "axs[1].imshow(prep_im(Y, far=4.5))" ] @@ -834,10 +850,10 @@ " tr1 = update(key, tr0, vs[order[i]])\n", "\n", " X, Y = tr1.retval\n", - " fig, axs = plt.subplots(1,3, figsize=(4,2))\n", + " fig, axs = plt.subplots(1, 3, figsize=(4, 2))\n", " axs[0].imshow(prep_im(X, far=4.5))\n", " axs[1].imshow(prep_im(Y, far=4.5))\n", - " axs[2].imshow(X[...,2] - Y[...,2], cmap=\"bwr\", vmin=-0.02, vmax=0.02)" + " axs[2].imshow(X[..., 2] - Y[..., 2], cmap=\"bwr\", vmin=-0.02, vmax=0.02)" ] }, { @@ -942,18 +958,18 @@ } ], "source": [ - "method = [\"nearest\", 
\"bilinear\",\"trilinear\", \"triangle\", \"cubic\"][0]\n", + "method = [\"nearest\", \"bilinear\", \"trilinear\", \"triangle\", \"cubic\"][0]\n", "\n", - "Y_down = scale.resize(Y[...,2], (10,10), method)\n", - "Y_up = scale.resize(Y_down, (20,20), method)\n", + "Y_down = scale.resize(Y[..., 2], (10, 10), method)\n", + "Y_up = scale.resize(Y_down, (20, 20), method)\n", "\n", "\n", "# =================================\n", - "fig, axs = plt.subplots(1,4,figsize=(10,3))\n", - "axs[0].imshow(Y[...,2], vmax=1.)\n", - "axs[1].imshow(Y_down, vmax=1.)\n", - "axs[2].imshow(Y_up, vmax=1.)\n", - "axs[3].imshow(Y[...,2] - Y_up, cmap=\"bwr\", vmin=-0.05,vmax=0.05)" + "fig, axs = plt.subplots(1, 4, figsize=(10, 3))\n", + "axs[0].imshow(Y[..., 2], vmax=1.0)\n", + "axs[1].imshow(Y_down, vmax=1.0)\n", + "axs[2].imshow(Y_up, vmax=1.0)\n", + "axs[3].imshow(Y[..., 2] - Y_up, cmap=\"bwr\", vmin=-0.05, vmax=0.05)" ] }, { diff --git a/scripts/_mkl/notebooks/33 - LH debugging.ipynb b/scripts/_mkl/notebooks/33 - LH debugging.ipynb index 7c605c7c..c910bf6a 100644 --- a/scripts/_mkl/notebooks/33 - LH debugging.ipynb +++ b/scripts/_mkl/notebooks/33 - LH debugging.ipynb @@ -60,14 +60,16 @@ "metadata": {}, "outputs": [], "source": [ - "_scaling = 1e-3\n", - "model_dir = os.path.join(b3d.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "_scaling = 1e-3\n", + "model_dir = os.path.join(b3d.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", " mesh = trimesh.load(mesh_path)\n", - " mesh.vertices *= _scaling \n", + " mesh.vertices *= _scaling\n", " meshes.append(mesh)\n", "\n", "\n", @@ -75,7 +77,7 @@ "mesh_path = os.path.join(b3d.utils.get_assets_dir(), \"sample_objs/cube.obj\")\n", "mesh = trimesh.load(mesh_path)\n", "mesh.vertices *= 1e-9\n", - "meshes.append(mesh)\n" + "meshes.append(mesh)" ] }, { @@ -103,13 +105,11 @@ "h = 200\n", "f = 300\n", "intr = b3d.Intrinsics(\n", - " width = w, height = h,\n", - " fx = f, fy = f,\n", - " cx = w/2 - 0.5, cy = h/2 - 0.5,\n", - " near = 1e-4, far = 5.0\n", + " width=w, height=h, fx=f, fy=f, cx=w / 2 - 0.5, cy=h / 2 - 0.5, near=1e-4, far=5.0\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -119,7 +119,7 @@ "outputs": [], "source": [ "def prep_im(Y, far=5.0, eps=1e-6):\n", - " im = jnp.where(Y[:,:,2]>= far - eps, jnp.inf, Y[:,:,2])\n", + " im = jnp.where(Y[:, :, 2] >= far - eps, jnp.inf, Y[:, :, 2])\n", " return im" ] }, @@ -142,8 +142,8 @@ "data = data.item()\n", "\n", "scene_ch = data[\"choices\"]\n", - "cam = data[\"choices\"][\"camera_pose\"]\n", - "table = data[\"choices\"][\"root_pose_0\"]\n", + "cam = data[\"choices\"][\"camera_pose\"]\n", + "table = data[\"choices\"][\"root_pose_0\"]\n", "scene_model_args = data[\"args\"]" ] }, @@ -180,7 +180,7 @@ "w, tr = scene_model.importance(key, scene_ch, scene_model_args)\n", "cam, ps, inds = tr.retval\n", "\n", - "X = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + "X = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X))" @@ -208,19 +208,23 @@ ], "source": [ "w_orig = h_orig = 200\n", - "w_att = h_att = 50\n", - "fx = fy = 300\n", - 
"center = jnp.array([90., 70.])\n", + "w_att = h_att = 50\n", + "fx = fy = 300\n", + "center = jnp.array([90.0, 70.0])\n", "\n", "intr = b3d.Intrinsics(\n", - " width = 2*w_att+1, height = 2*h_att+1,\n", - " fx = fx, fy = fy,\n", - " cx = w_orig/2 - 0.5 + w_att - center[0], cy = h_orig/2 - 0.5 + h_att - center[1],\n", - " near = 1e-4, \n", - " far = 5.0\n", + " width=2 * w_att + 1,\n", + " height=2 * h_att + 1,\n", + " fx=fx,\n", + " fy=fy,\n", + " cx=w_orig / 2 - 0.5 + w_att - center[0],\n", + " cy=h_orig / 2 - 0.5 + h_att - center[1],\n", + " near=1e-4,\n", + " far=5.0,\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -250,7 +254,7 @@ } ], "source": [ - "X = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + "X = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X))" @@ -272,11 +276,13 @@ "scene_model_args = (\n", " jnp.arange(2),\n", " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b3d.RENDERER.model_box_dims\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b3d.RENDERER.model_box_dims,\n", ")\n", - "scene_model = make_table_scene_model()" + "scene_model = make_table_scene_model()" ] }, { @@ -310,10 +316,10 @@ "_, tr0 = scene_model.importance(key, scene_ch, scene_model_args)\n", "\n", "cam, ps, inds = tr0.retval\n", - "Y = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + "Y = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", "# ========================\n", - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "plt.imshow(prep_im(Y, far=4.5))" ] }, @@ -328,11 +334,12 @@ "\n", "def get_linear_grid(shape, bounds, flat=False):\n", " \"\"\"Create a (linear) grid of a given shape and bounds.\"\"\"\n", - " \n", - " linspaces = [jnp.linspace(*b, d) for b,d in zip(bounds, shape)]\n", - " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing='ij'), axis=-1)\n", - " if flat: vs = vs.reshape(-1,len(shape))\n", - " \n", + "\n", + " linspaces = [jnp.linspace(*b, d) for b, d in zip(bounds, shape)]\n", + " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing=\"ij\"), axis=-1)\n", + " if flat:\n", + " vs = vs.reshape(-1, len(shape))\n", + "\n", " return vs" ] }, @@ -347,7 +354,7 @@ " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", " cam, ps, inds = tr1.retval\n", - " Y = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + " Y = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", " return Y" ] }, @@ -357,21 +364,18 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def make_eval_contact(score_func):\n", - "\n", " def _eval_contact(key, tr0, v, X):\n", " ch = genjax.choice_map({\"contact_params_1\": v})\n", " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", " cam, ps, inds = tr1.retval\n", - " Y = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", - " return score_func(Y,X)\n", + " Y = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", + " return score_func(Y, X)\n", "\n", + " eval_contact = jit(vmap(_eval_contact, (None, None, 0, None)))\n", "\n", - " eval_contact = jit(vmap(_eval_contact, (None,None,0, None)))\n", - " \n", - " return eval_contact\n" + " return eval_contact" ] }, { @@ -381,16 +385,12 @@ 
"outputs": [], "source": [ "# def score_against_latent(Y):\n", - " # D = X[...,2] - Y[...,2]\n", + "# D = X[...,2] - Y[...,2]\n", + "\n", "\n", - " \n", "def score_images(rendered, observed):\n", " distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", - " logps = jax.scipy.stats.norm.pdf(\n", - " distances,\n", - " loc = 0.0, \n", - " scale = 0.1\n", - " )\n", + " logps = jax.scipy.stats.norm.pdf(distances, loc=0.0, scale=0.1)\n", " logps = jnp.where(distances < 0.01, 0.1 + logps, 0.1)\n", " image_probability = logps.mean()\n", " return image_probability\n", @@ -432,11 +432,11 @@ } ], "source": [ - "dx = 0.022\n", + "dx = 0.022\n", "dhd = jnp.pi\n", "v0 = scene_ch[\"contact_params_1\"]\n", - "vs = get_linear_grid((25, 25, 25), ((-dx,dx), (-dx,dx), (-dhd,dhd)), flat=True)\n", - "vs += jnp.array([-0.042,0.105, 0.0])\n", + "vs = get_linear_grid((25, 25, 25), ((-dx, dx), (-dx, dx), (-dhd, dhd)), flat=True)\n", + "vs += jnp.array([-0.042, 0.105, 0.0])\n", "\n", "print(vs.shape)\n", "print((b3d.RENDERER.intrinsics.height, b3d.RENDERER.intrinsics.width))\n", @@ -445,12 +445,12 @@ "\n", "r = 0.003\n", "# =======================================\n", - "fig, ax = plt.subplots(1,1,figsize=(6,6))\n", - "ax.spines['top'].set_visible(False)\n", - "ax.spines['right'].set_visible(False)\n", - "zoom_in(vs[:,:2], 0.01, ax=ax)\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", + "ax.spines[\"top\"].set_visible(False)\n", + "ax.spines[\"right\"].set_visible(False)\n", + "zoom_in(vs[:, :2], 0.01, ax=ax)\n", "plot_poses(vs, sc, r=r, linewidth=2, ax=ax, q=0.98)\n", - "plot_pose(scene_ch[\"contact_params_1\"], r=r/2, c=\"red\")\n", + "plot_pose(scene_ch[\"contact_params_1\"], r=r / 2, c=\"red\")\n", "# plt.scatter(*v0[:2], c=\"r\", s=10)" ] }, @@ -483,7 +483,7 @@ "source": [ "order = jnp.argsort(sc)\n", "# =================\n", - "plt.figure(figsize=(6,1))\n", + "plt.figure(figsize=(6, 1))\n", "plt.plot(sc[order])" ] }, @@ -520,7 +520,7 @@ "\n", "X, Y = tr1.retval\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(10,5))\n", + "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "axs[0].imshow(prep_im(X, far=4.5))\n", "axs[1].imshow(prep_im(Y, far=4.5))" ] @@ -738,10 +738,10 @@ " tr1 = update(key, tr0, vs[order[i]])\n", "\n", " X, Y = tr1.retval\n", - " fig, axs = plt.subplots(1,3, figsize=(4,2))\n", + " fig, axs = plt.subplots(1, 3, figsize=(4, 2))\n", " axs[0].imshow(prep_im(X, far=4.5))\n", " axs[1].imshow(prep_im(Y, far=4.5))\n", - " axs[2].imshow(X[...,2] - Y[...,2], cmap=\"bwr\", vmin=-0.02, vmax=0.02)" + " axs[2].imshow(X[..., 2] - Y[..., 2], cmap=\"bwr\", vmin=-0.02, vmax=0.02)" ] }, { @@ -846,18 +846,18 @@ } ], "source": [ - "method = [\"nearest\", \"bilinear\",\"trilinear\", \"triangle\", \"cubic\"][0]\n", + "method = [\"nearest\", \"bilinear\", \"trilinear\", \"triangle\", \"cubic\"][0]\n", "\n", - "Y_down = scale.resize(Y[...,2], (10,10), method)\n", - "Y_up = scale.resize(Y_down, (20,20), method)\n", + "Y_down = scale.resize(Y[..., 2], (10, 10), method)\n", + "Y_up = scale.resize(Y_down, (20, 20), method)\n", "\n", "\n", "# =================================\n", - "fig, axs = plt.subplots(1,4,figsize=(10,3))\n", - "axs[0].imshow(Y[...,2], vmax=1.)\n", - "axs[1].imshow(Y_down, vmax=1.)\n", - "axs[2].imshow(Y_up, vmax=1.)\n", - "axs[3].imshow(Y[...,2] - Y_up, cmap=\"bwr\", vmin=-0.05,vmax=0.05)" + "fig, axs = plt.subplots(1, 4, figsize=(10, 3))\n", + "axs[0].imshow(Y[..., 2], vmax=1.0)\n", + "axs[1].imshow(Y_down, vmax=1.0)\n", + "axs[2].imshow(Y_up, vmax=1.0)\n", + "axs[3].imshow(Y[..., 2] - 
Y_up, cmap=\"bwr\", vmin=-0.05, vmax=0.05)" ] }, { diff --git a/scripts/_mkl/notebooks/33b - LH debugging step func.ipynb b/scripts/_mkl/notebooks/33b - LH debugging step func.ipynb index 712df0ef..afaaab76 100644 --- a/scripts/_mkl/notebooks/33b - LH debugging step func.ipynb +++ b/scripts/_mkl/notebooks/33b - LH debugging step func.ipynb @@ -60,14 +60,16 @@ "metadata": {}, "outputs": [], "source": [ - "_scaling = 1e-3\n", - "model_dir = os.path.join(b3d.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "_scaling = 1e-3\n", + "model_dir = os.path.join(b3d.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", " mesh = trimesh.load(mesh_path)\n", - " mesh.vertices *= _scaling \n", + " mesh.vertices *= _scaling\n", " meshes.append(mesh)\n", "\n", "\n", @@ -75,7 +77,7 @@ "mesh_path = os.path.join(b3d.utils.get_assets_dir(), \"sample_objs/cube.obj\")\n", "mesh = trimesh.load(mesh_path)\n", "mesh.vertices *= 1e-9\n", - "meshes.append(mesh)\n" + "meshes.append(mesh)" ] }, { @@ -103,13 +105,11 @@ "h = 200\n", "f = 300\n", "intr = b3d.Intrinsics(\n", - " width = w, height = h,\n", - " fx = f, fy = f,\n", - " cx = w/2 - 0.5, cy = h/2 - 0.5,\n", - " near = 1e-4, far = 5.0\n", + " width=w, height=h, fx=f, fy=f, cx=w / 2 - 0.5, cy=h / 2 - 0.5, near=1e-4, far=5.0\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -119,7 +119,7 @@ "outputs": [], "source": [ "def prep_im(Y, far=5.0, eps=1e-6):\n", - " im = jnp.where(Y[:,:,2]>= far - eps, jnp.inf, Y[:,:,2])\n", + " im = jnp.where(Y[:, :, 2] >= far - eps, jnp.inf, Y[:, :, 2])\n", " return im" ] }, @@ -142,8 +142,8 @@ "data = data.item()\n", "\n", "scene_ch = data[\"choices\"]\n", - "cam = data[\"choices\"][\"camera_pose\"]\n", - "table = data[\"choices\"][\"root_pose_0\"]\n", + "cam = data[\"choices\"][\"camera_pose\"]\n", + "table = data[\"choices\"][\"root_pose_0\"]\n", "scene_model_args = data[\"args\"]" ] }, @@ -180,7 +180,7 @@ "w, tr = scene_model.importance(key, scene_ch, scene_model_args)\n", "cam, ps, inds = tr.retval\n", "\n", - "X = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + "X = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X))" @@ -195,7 +195,7 @@ "def get_patch(X, c, w, h):\n", " i = int(c[1])\n", " j = int(c[0])\n", - " patch = X[ i-h : i+h+1 , j-w : j+w+1 ]\n", + " patch = X[i - h : i + h + 1, j - w : j + w + 1]\n", " return patch" ] }, @@ -226,25 +226,25 @@ } ], "source": [ - "c = jnp.array([90,70])\n", - "att_w = 5\n", - "att_h = 5\n", + "c = jnp.array([90, 70])\n", + "att_w = 5\n", + "att_h = 5\n", "\n", "patch = get_patch(X, c, att_w, att_h)\n", "\n", "# =====================\n", - "fig, axs = plt.subplots(1, 2, figsize=(20,10))\n", - "axs[0].set_xlim(c[0]-att_w-20, c[0]+att_w+20)\n", - "axs[0].set_ylim(c[1]+att_h+20, c[1]-att_h-20)\n", + "fig, axs = plt.subplots(1, 2, figsize=(20, 10))\n", + "axs[0].set_xlim(c[0] - att_w - 20, c[0] + att_w + 20)\n", + "axs[0].set_ylim(c[1] + att_h + 20, c[1] - att_h - 20)\n", "axs[0].imshow(prep_im(X))\n", "axs[0].scatter(*c, c=\"r\", s=20, marker=\"x\")\n", - "axs[0].scatter(c[0]-att_w, c[1]-att_h, 
c=\"r\", s=10)\n", - "axs[0].scatter(c[0]-att_w, c[1]+att_h, c=\"r\", s=10)\n", - "axs[0].scatter(c[0]+att_w, c[1]-att_h, c=\"r\", s=10)\n", - "axs[0].scatter(c[0]+att_w, c[1]+att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] - att_w, c[1] - att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] - att_w, c[1] + att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] + att_w, c[1] - att_h, c=\"r\", s=10)\n", + "axs[0].scatter(c[0] + att_w, c[1] + att_h, c=\"r\", s=10)\n", "\n", "axs[1].imshow(prep_im(patch))\n", - "axs[1].scatter(att_w,att_h, c=\"r\", s=20, marker=\"x\")\n" + "axs[1].scatter(att_w, att_h, c=\"r\", s=20, marker=\"x\")" ] }, { @@ -279,7 +279,7 @@ "\n", "sensor_model = make_simple_step_sensor_model(5.0)\n", "key = keysplit(key)\n", - "X_ = sensor_model(patch[...,:3], 1e-6, 0.0)(key)\n", + "X_ = sensor_model(patch[..., :3], 1e-6, 0.0)(key)\n", "\n", "# =====================\n", "plt.imshow(prep_im(X_, far=4.5))" @@ -300,6 +300,7 @@ "source": [ "from bayes3d._mkl.simple_likelihood import wrap_into_dist\n", "\n", + "\n", "def score_images(observed, rendered, sig, *args):\n", " distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", " return (distances < sig).mean()\n", @@ -339,22 +340,20 @@ "\n", "\n", "def make_patch_model(scene_model_args):\n", - "\n", - " scene_model = make_table_scene_model()\n", + " scene_model = make_table_scene_model()\n", " # sensor_model = make_simple_step_sensor_model(5.0)\n", " sensor_model = wrap_into_dist(score_images)\n", "\n", " @genjax.gen\n", " def patch_model(sig, out):\n", - "\n", " cam, ps, inds = scene_model(*scene_model_args) @ \"scene\"\n", - " Y = b3d.RENDERER.render(inv(cam) @ ps , inds)\n", + " Y = b3d.RENDERER.render(inv(cam) @ ps, inds)\n", " # X = sensor_model.inline(Y[...,:3], sig, out)\n", - " X = sensor_model(Y[...,:3], sig, out) @ \"X\"\n", + " X = sensor_model(Y[..., :3], sig, out) @ \"X\"\n", "\n", - " return X,Y\n", + " return X, Y\n", "\n", - " return patch_model\n" + " return patch_model" ] }, { @@ -379,19 +378,23 @@ ], "source": [ "w_orig = h_orig = 200\n", - "w_att = h_att = 5\n", - "fx = fy = 300\n", - "center = jnp.array([90., 70.])\n", + "w_att = h_att = 5\n", + "fx = fy = 300\n", + "center = jnp.array([90.0, 70.0])\n", "\n", "intr = b3d.Intrinsics(\n", - " width = 2*w_att+1, height = 2*h_att+1,\n", - " fx = fx, fy = fy,\n", - " cx = w_orig/2 - 0.5 + w_att - center[0], cy = h_orig/2 - 0.5 + h_att - center[1],\n", - " near = 1e-4, \n", - " far = 5.0\n", + " width=2 * w_att + 1,\n", + " height=2 * h_att + 1,\n", + " fx=fx,\n", + " fy=fy,\n", + " cx=w_orig / 2 - 0.5 + w_att - center[0],\n", + " cy=h_orig / 2 - 0.5 + h_att - center[1],\n", + " near=1e-4,\n", + " far=5.0,\n", ")\n", "b3d.setup_renderer(intr)\n", - "for mesh in meshes: b3d.RENDERER.add_mesh(mesh, center_mesh=True)" + "for mesh in meshes:\n", + " b3d.RENDERER.add_mesh(mesh, center_mesh=True)" ] }, { @@ -403,9 +406,11 @@ "scene_model_args = (\n", " jnp.arange(2),\n", " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b3d.RENDERER.model_box_dims\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b3d.RENDERER.model_box_dims,\n", ")\n", "\n", "model = make_patch_model(scene_model_args)" @@ -445,22 +450,26 @@ } ], "source": [ - "ch = genjax.choice_map({\n", - " \"scene\": genjax.choice_map({\n", - " \"parent_0\": -1,\n", 
- " \"parent_1\": 0,\n", - " \"camera_pose\": cam,\n", - " \"root_pose_0\": table,\n", - " \"id_0\": jnp.int32(21), # Atomic Table\n", - " \"id_1\": jnp.int32(13), # Mug\n", - " \"face_parent_1\": 1, # That's the top face of the table\n", - " \"face_child_1\": 3, # That's a bottom face of the mug\n", - " \"contact_params_1\": scene_ch[\"contact_params_1\"],\n", - " }),\n", - " # \"X\": genjax.vector_choice_map({\n", - " # \"measurement\": \n", - " # }) \n", - "})\n", + "ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"camera_pose\": cam,\n", + " \"root_pose_0\": table,\n", + " \"id_0\": jnp.int32(21), # Atomic Table\n", + " \"id_1\": jnp.int32(13), # Mug\n", + " \"face_parent_1\": 1, # That's the top face of the table\n", + " \"face_child_1\": 3, # That's a bottom face of the mug\n", + " \"contact_params_1\": scene_ch[\"contact_params_1\"],\n", + " }\n", + " ),\n", + " # \"X\": genjax.vector_choice_map({\n", + " # \"measurement\":\n", + " # })\n", + " }\n", + ")\n", "\n", "\n", "sig = 1e-6\n", @@ -473,7 +482,7 @@ "\n", "P, _ = tr0.retval\n", "# ========================\n", - "plt.figure(figsize=(1,1))\n", + "plt.figure(figsize=(1, 1))\n", "plt.imshow(prep_im(P, far=4.5))" ] }, @@ -488,11 +497,12 @@ "\n", "def get_linear_grid(shape, bounds, flat=False):\n", " \"\"\"Create a (linear) grid of a given shape and bounds.\"\"\"\n", - " \n", - " linspaces = [jnp.linspace(*b, d) for b,d in zip(bounds, shape)]\n", - " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing='ij'), axis=-1)\n", - " if flat: vs = vs.reshape(-1,len(shape))\n", - " \n", + "\n", + " linspaces = [jnp.linspace(*b, d) for b, d in zip(bounds, shape)]\n", + " vs = jnp.stack(jnp.meshgrid(*linspaces, indexing=\"ij\"), axis=-1)\n", + " if flat:\n", + " vs = vs.reshape(-1, len(shape))\n", + "\n", " return vs" ] }, @@ -503,11 +513,15 @@ "outputs": [], "source": [ "def update(key, tr0, v):\n", - " ch = genjax.choice_map({ \"scene\": \n", - " genjax.choice_map({\n", - " \"contact_params_1\": v,\n", - " })\n", - " })\n", + " ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"contact_params_1\": v,\n", + " }\n", + " )\n", + " }\n", + " )\n", " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", " return tr1" @@ -520,18 +534,22 @@ "outputs": [], "source": [ "def _eval_contact(key, tr0, v):\n", - " ch = genjax.choice_map({ \"scene\": \n", - " genjax.choice_map({\n", - " \"contact_params_1\": v,\n", - " })\n", - " })\n", + " ch = genjax.choice_map(\n", + " {\n", + " \"scene\": genjax.choice_map(\n", + " {\n", + " \"contact_params_1\": v,\n", + " }\n", + " )\n", + " }\n", + " )\n", " diffs = argdiffs(tr0.args)\n", " (_, w, tr1, _) = tr0.update(key, ch, diffs)\n", "\n", " return tr1.get_score()\n", "\n", "\n", - "eval_contact = jit(vmap(_eval_contact, (None,None,0)))\n" + "eval_contact = jit(vmap(_eval_contact, (None, None, 0)))" ] }, { @@ -596,11 +614,11 @@ } ], "source": [ - "dx = 0.03\n", + "dx = 0.03\n", "dhd = jnp.pi\n", "v0 = scene_ch[\"contact_params_1\"]\n", - "vs = get_linear_grid((35, 35, 35), ((-dx,dx), (-dx,dx), (-dhd,dhd)), flat=True)\n", - "vs += jnp.array([-0.04,0.11, 0.0])\n", + "vs = get_linear_grid((35, 35, 35), ((-dx, dx), (-dx, dx), (-dhd, dhd)), flat=True)\n", + "vs += jnp.array([-0.04, 0.11, 0.0])\n", "\n", "print(vs.shape)\n", "print((b3d.RENDERER.intrinsics.height, b3d.RENDERER.intrinsics.width))\n", @@ -615,15 +633,14 @@ "\"\"\")\n", "\n", "\n", - "\n", "r = 
0.003\n", "# =======================================\n", - "fig, ax = plt.subplots(1,1,figsize=(9,9))\n", - "ax.spines['top'].set_visible(False)\n", - "ax.spines['right'].set_visible(False)\n", - "zoom_in(vs[:,:2], 0.02, ax=ax)\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 9))\n", + "ax.spines[\"top\"].set_visible(False)\n", + "ax.spines[\"right\"].set_visible(False)\n", + "zoom_in(vs[:, :2], 0.02, ax=ax)\n", "plot_poses(vs, sc, r=r, linewidth=2, ax=ax, q=0.98)\n", - "plot_pose(scene_ch[\"contact_params_1\"], r=r/2, c=\"red\")\n", + "plot_pose(scene_ch[\"contact_params_1\"], r=r / 2, c=\"red\")\n", "plt.scatter(*v0[:2], c=\"r\", s=10)" ] }, @@ -665,7 +682,7 @@ "\n", "order = jnp.argsort(sc)\n", "# =================\n", - "plt.figure(figsize=(6,1))\n", + "plt.figure(figsize=(6, 1))\n", "plt.plot(sc[order][-100:])" ] }, @@ -702,7 +719,7 @@ "\n", "X, Y = tr1.retval\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(10,5))\n", + "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "axs[0].imshow(prep_im(X, far=4.5))\n", "axs[1].imshow(prep_im(Y, far=4.5))" ] @@ -920,11 +937,11 @@ " tr1 = update(key, tr0, vs[order[i]])\n", "\n", " X, Y = tr1.retval\n", - " fig, axs = plt.subplots(1,4, figsize=(6,2))\n", + " fig, axs = plt.subplots(1, 4, figsize=(6, 2))\n", " axs[0].imshow(prep_im(X, far=4.5))\n", " axs[1].imshow(prep_im(Y, far=4.5))\n", - " axs[2].imshow(X[...,2] - Y[...,2], cmap=\"bwr\", vmin=-sig, vmax=sig)\n", - " axs[3].imshow(jnp.linalg.norm(X[...,:3] - Y[...,:3], axis=-1) < sig)" + " axs[2].imshow(X[..., 2] - Y[..., 2], cmap=\"bwr\", vmin=-sig, vmax=sig)\n", + " axs[3].imshow(jnp.linalg.norm(X[..., :3] - Y[..., :3], axis=-1) < sig)" ] }, { diff --git a/scripts/_mkl/notebooks/99 - Traceviz Test.ipynb b/scripts/_mkl/notebooks/99 - Traceviz Test.ipynb index 5efa6c35..964bf12f 100644 --- a/scripts/_mkl/notebooks/99 - Traceviz Test.ipynb +++ b/scripts/_mkl/notebooks/99 - Traceviz Test.ipynb @@ -8,7 +8,7 @@ "source": [ "import traceviz.client\n", "import numpy as np\n", - "from traceviz.proto import viz_pb2\n", + "from traceviz.proto import viz_pb2\n", "import json" ] }, @@ -56,19 +56,28 @@ ], "source": [ "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\"type\": \"setup\",}))\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"setup\",\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", - "\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "\n", "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\n", - " \"type\": \"TEST\",\n", - " \"data\": np.ones((10,10)),\n", - " }))\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"TEST\",\n", + " \"data\": np.ones((10, 10)),\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -98,10 +107,10 @@ "t = 100\n", "n = 100\n", "key, keys = keysplit(key, 1, 3)\n", - "x0 = jax.random.normal(keys[0], (1,n,3))\n", - "xs = x0 + jnp.cumsum(0.03*jax.random.uniform(keys[0], (t,n,3)), axis=0)\n", - "cs = jnp.tile( jax.random.uniform(keys[1], (n,4)), (t,1,1))\n", - "ss = 0.1*jnp.tile( jax.random.uniform(keys[2], (n,4)), (t,1,1))\n", + "x0 = jax.random.normal(keys[0], (1, n, 3))\n", + "xs = x0 + jnp.cumsum(0.03 * jax.random.uniform(keys[0], (t, n, 3)), axis=0)\n", + "cs = jnp.tile(jax.random.uniform(keys[1], (n, 4)), 
(t, 1, 1))\n", + "ss = 0.1 * jnp.tile(jax.random.uniform(keys[2], (n, 4)), (t, 1, 1))\n", "\n", "xs.shape, cs.shape" ] @@ -124,20 +133,26 @@ ], "source": [ "msg = viz_pb2.Message()\n", - "msg.pytree.MergeFrom(traceviz.client.to_pytree_msg({\"type\": \"setup\",}))\n", + "msg.pytree.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"type\": \"setup\",\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "\n", "\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(xs), \n", - " 'colors': np.array(cs),\n", - " 'scales': np.array(ss)\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\"centers\": np.array(xs), \"colors\": np.array(cs), \"scales\": np.array(ss)}\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -1830,6 +1845,7 @@ ], "source": [ "import bayes3d\n", + "\n", "r3d_path = \"data/2023-09-05--16-08-21.r3d.zip\"\n", "colors, depths, poses, intrinsics = bayes3d.utils.load_r3d(r3d_path)" ] @@ -1852,9 +1868,9 @@ ], "source": [ "depths.shape\n", - "t=0\n", + "t = 0\n", "\n", - "plt.figure(figsize=(10,2))\n", + "plt.figure(figsize=(10, 2))\n", "plt.hist(depths[t].flatten(), bins=100);" ] }, @@ -1866,7 +1882,7 @@ "source": [ "import bayes3d\n", "from bayes3d.transforms_3d import unproject_depth, apply_transform\n", - "from bayes3d.utils import resize\n" + "from bayes3d.utils import resize" ] }, { @@ -1888,8 +1904,13 @@ "metadata": {}, "outputs": [], "source": [ - "cs = np.array([resize(colors[i], int(intrinsics.height), int(intrinsics.width)) for i in range(len(colors))])\n", - "cs = np.concatenate([cs/255, 1.*np.ones((*cs.shape[:3], 1))], axis=-1)" + "cs = np.array(\n", + " [\n", + " resize(colors[i], int(intrinsics.height), int(intrinsics.width))\n", + " for i in range(len(colors))\n", + " ]\n", + ")\n", + "cs = np.concatenate([cs / 255, 1.0 * np.ones((*cs.shape[:3], 1))], axis=-1)" ] }, { @@ -1910,7 +1931,7 @@ ], "source": [ "T = xs.shape[0]\n", - "N = xs.shape[1]*xs.shape[2] \n", + "N = xs.shape[1] * xs.shape[2]\n", "T, N, xs.shape, cs.shape, poses.shape" ] }, @@ -1922,7 +1943,7 @@ "source": [ "from scipy.spatial.transform import Rotation as Rot\n", "\n", - "A = Rot.from_euler('xyz', [0,0, 0], degrees=True).as_matrix()" + "A = Rot.from_euler(\"xyz\", [0, 0, 0], degrees=True).as_matrix()" ] }, { @@ -1935,7 +1956,7 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -1960,13 +1981,17 @@ "for t in [0]:\n", " msg = viz_pb2.Message()\n", " msg.payload.json = json.dumps({\"type\": \"spheres\"})\n", - " msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(xs).reshape(T,N,3)[t,:k]@A.T, \n", - " 'colors': np.array(cs).reshape(T,N,4)[t,:k], \n", - " 'scales': 0.0025*np.array(ss).reshape(T,N,1)[t,:k], \n", - " }))\n", + " msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(xs).reshape(T, N, 3)[t, :k] @ A.T,\n", + " \"colors\": np.array(cs).reshape(T, N, 4)[t, :k],\n", + " 
\"scales\": 0.0025 * np.array(ss).reshape(T, N, 1)[t, :k],\n", + " }\n", + " )\n", + " )\n", " stub = traceviz.client.connect()\n", - " print('response: ', stub.Broadcast(msg))" + " print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -1997,7 +2022,11 @@ "outputs": [], "source": [ "from sklearn.mixture import GaussianMixture\n", - "from bayes3d._mkl.trimesh_to_gaussians import get_mean_colors, pack_transform, ellipsoid_embedding\n", + "from bayes3d._mkl.trimesh_to_gaussians import (\n", + " get_mean_colors,\n", + " pack_transform,\n", + " ellipsoid_embedding,\n", + ")\n", "\n", "\n", "key = keysplit(key)\n", @@ -2007,18 +2036,21 @@ "t = 0\n", "key = keysplit(key)\n", "n_components = 500\n", - "noise = 0.0; \n", - "X = xs[t].reshape(-1,3)\n", - "X = X + np.random.randn(*X.shape)*noise\n", - "means_init = X[np.random.choice(len(X), n_components, replace=False)]\n", + "noise = 0.0\n", + "X = xs[t].reshape(-1, 3)\n", + "X = X + np.random.randn(*X.shape) * noise\n", + "means_init = X[np.random.choice(len(X), n_components, replace=False)]\n", "\n", "\n", "# Fit the GMM\n", "# -----------\n", - "gm = GaussianMixture(n_components=n_components, \n", - " tol=1e-3, max_iter=100, \n", - " covariance_type=\"full\", \n", - " means_init=means_init).fit(X)\n" + "gm = GaussianMixture(\n", + " n_components=n_components,\n", + " tol=1e-3,\n", + " max_iter=100,\n", + " covariance_type=\"full\",\n", + " means_init=means_init,\n", + ").fit(X)" ] }, { @@ -2039,12 +2071,14 @@ ], "source": [ "ws = gm.weights_\n", - "mus = gm.means_\n", - "covs = gm.covariances_\n", - "labels = gm.predict(X)\n", - "choleskys = vmap(ellipsoid_embedding)(covs)\n", - "transforms = vmap(pack_transform, (0,0,None))(mus, choleskys, 2.0)\n", - "mean_colors, nums = get_mean_colors(np.array(cs[t]).reshape(-1,4), gm.n_components, labels)\n", + "mus = gm.means_\n", + "covs = gm.covariances_\n", + "labels = gm.predict(X)\n", + "choleskys = vmap(ellipsoid_embedding)(covs)\n", + "transforms = vmap(pack_transform, (0, 0, None))(mus, choleskys, 2.0)\n", + "mean_colors, nums = get_mean_colors(\n", + " np.array(cs[t]).reshape(-1, 4), gm.n_components, labels\n", + ")\n", "valid = nums > 0\n", "valid.sum()\n", "transforms.shape, mean_colors.shape, valid.shape" @@ -2077,9 +2111,9 @@ } ], "source": [ - "ms = jnp.max(jnp.linalg.norm(choleskys, axis=1), axis=-1) \n", + "ms = jnp.max(jnp.linalg.norm(choleskys, axis=1), axis=-1)\n", "\n", - "plt.figure(figsize=(2,2))\n", + "plt.figure(figsize=(2, 2))\n", "plt.scatter(ms, ws)" ] }, @@ -2107,15 +2141,19 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"gaussians\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms )[sub][None], \n", - " 'colors': np.array(mean_colors)[sub][None] \n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"transforms\": np.array(transforms)[sub][None],\n", + " \"colors\": np.array(mean_colors)[sub][None],\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -2146,16 +2184,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", 
"msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': np.array(xs).reshape(T,N,3)[:50,:k]@A.T, \n", - " 'colors': np.array(cs).reshape(T,N,4)[:50,:k], \n", - " 'scales': 0.002*np.array(ss).reshape(T,N,1)[:50,:k], \n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": np.array(xs).reshape(T, N, 3)[:50, :k] @ A.T,\n", + " \"colors\": np.array(cs).reshape(T, N, 4)[:50, :k],\n", + " \"scales\": 0.002 * np.array(ss).reshape(T, N, 1)[:50, :k],\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -2192,19 +2234,29 @@ "\n", "data = np.load(\"data/ramen.npz\")\n", "depths = data[\"depths\"]\n", - "colors= data[\"colors\"]\n", + "colors = data[\"colors\"]\n", "\n", "intrinsics = bayes3d.Intrinsics(*data[\"intrinsics\"])\n", "\n", - "xs = np.array([bayes3d.transforms_3d.unproject_depth(depths[i], intrinsics) for i in range(len(colors))])\n", - "cs = np.array([bayes3d.utils.resize(colors[i], int(intrinsics.height), int(intrinsics.width)) for i in range(len(colors))])\n", - "cs = np.concatenate([cs, np.ones((*cs.shape[:3], 1))], axis=-1)/255\n", + "xs = np.array(\n", + " [\n", + " bayes3d.transforms_3d.unproject_depth(depths[i], intrinsics)\n", + " for i in range(len(colors))\n", + " ]\n", + ")\n", + "cs = np.array(\n", + " [\n", + " bayes3d.utils.resize(colors[i], int(intrinsics.height), int(intrinsics.width))\n", + " for i in range(len(colors))\n", + " ]\n", + ")\n", + "cs = np.concatenate([cs, np.ones((*cs.shape[:3], 1))], axis=-1) / 255\n", "ss = np.ones(xs.shape[:3] + (1,))\n", "\n", "xs[:, :, :, 1] *= -1\n", "\n", "T = xs.shape[0]\n", - "N = xs.shape[1]*xs.shape[2] \n", + "N = xs.shape[1] * xs.shape[2]\n", "T, N, depths.shape, colors.shape, xs.shape, cs.shape" ] }, @@ -2223,7 +2275,7 @@ "source": [ "from scipy.spatial.transform import Rotation as Rot\n", "\n", - "A = Rot.from_euler('xyz', [45, 0, 0], degrees=True).as_matrix()" + "A = Rot.from_euler(\"xyz\", [45, 0, 0], degrees=True).as_matrix()" ] }, { @@ -2248,16 +2300,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': xs.reshape(T,N,3)[0,:k]@A.T, \n", - " 'colors': cs.reshape(T,N,4)[0,:k], \n", - " 'scales': 0.002*ss.reshape(T,N,1)[0,:k], \n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": xs.reshape(T, N, 3)[0, :k] @ A.T,\n", + " \"colors\": cs.reshape(T, N, 4)[0, :k],\n", + " \"scales\": 0.002 * ss.reshape(T, N, 1)[0, :k],\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -2282,16 +2338,20 @@ "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", 
"msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated spheres\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'centers': xs.reshape(T,N,3)[:,:k]@A.T, \n", - " 'colors': cs.reshape(T,N,4)[:,:k], \n", - " 'scales': 0.002*ss.reshape(T,N,1)[:,:k], \n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\n", + " \"centers\": xs.reshape(T, N, 3)[:, :k] @ A.T,\n", + " \"colors\": cs.reshape(T, N, 4)[:, :k],\n", + " \"scales\": 0.002 * ss.reshape(T, N, 1)[:, :k],\n", + " }\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -2302,19 +2362,20 @@ "source": [ "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " \"test\": \"I am a test string\"\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg({\"test\": \"I am a test string\"})\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated gaussians\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms[:]), \n", - " 'colors': np.array(cs[:])\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\"transforms\": np.array(transforms[:]), \"colors\": np.array(cs[:])}\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { @@ -2339,16 +2400,18 @@ "t = 100\n", "n = 100\n", "key, keys = keysplit(key, 1, 3)\n", - "x0 = jax.random.normal(keys[0], (1,n,3))\n", - "xs = x0 + jnp.cumsum(0.03*jax.random.uniform(keys[0], (t,n,3)), axis=0)\n", - "cs = jnp.tile( jax.random.uniform(keys[1], (n,4)), (t,1,1))\n", - "ss = 0.1*jnp.tile( jax.random.uniform(keys[2], (n,4)), (t,1,1))\n", + "x0 = jax.random.normal(keys[0], (1, n, 3))\n", + "xs = x0 + jnp.cumsum(0.03 * jax.random.uniform(keys[0], (t, n, 3)), axis=0)\n", + "cs = jnp.tile(jax.random.uniform(keys[1], (n, 4)), (t, 1, 1))\n", + "ss = 0.1 * jnp.tile(jax.random.uniform(keys[2], (n, 4)), (t, 1, 1))\n", "\n", "\n", - "choleskys = jax.random.normal(key, (n,3,3))\n", - "choleskys = jnp.tile(choleskys, (t,1,1,1))\n", - "transforms = vmap(pack_transform, (0,0,None))(xs.reshape(-1,3), choleskys.reshape(-1,3,3), 0.26)\n", - "transforms = transforms.reshape(t,n,4,4)\n", + "choleskys = jax.random.normal(key, (n, 3, 3))\n", + "choleskys = jnp.tile(choleskys, (t, 1, 1, 1))\n", + "transforms = vmap(pack_transform, (0, 0, None))(\n", + " xs.reshape(-1, 3), choleskys.reshape(-1, 3, 3), 0.26\n", + ")\n", + "transforms = transforms.reshape(t, n, 4, 4)\n", "transforms.shape, choleskys.shape" ] }, @@ -2375,19 +2438,20 @@ "source": [ "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"setup\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " \"test\": \"I am a test string\"\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg({\"test\": \"I am a test 
string\"})\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))\n", + "print(\"response: \", stub.Broadcast(msg))\n", "msg = viz_pb2.Message()\n", "msg.payload.json = json.dumps({\"type\": \"animated gaussians\"})\n", - "msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({\n", - " 'transforms': np.array(transforms[:]), \n", - " 'colors': np.array(cs[:])\n", - "}))\n", + "msg.payload.data.MergeFrom(\n", + " traceviz.client.to_pytree_msg(\n", + " {\"transforms\": np.array(transforms[:]), \"colors\": np.array(cs[:])}\n", + " )\n", + ")\n", "stub = traceviz.client.connect()\n", - "print('response: ', stub.Broadcast(msg))" + "print(\"response: \", stub.Broadcast(msg))" ] }, { diff --git a/scripts/_mkl/notebooks/kubric/00 - Kubric Utils.ipynb b/scripts/_mkl/notebooks/kubric/00 - Kubric Utils.ipynb index c6c7dee7..16f82641 100644 --- a/scripts/_mkl/notebooks/kubric/00 - Kubric Utils.ipynb +++ b/scripts/_mkl/notebooks/kubric/00 - Kubric Utils.ipynb @@ -52,26 +52,27 @@ "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "\n", - "# This method is from the Kubric repo: \n", + "# This method is from the Kubric repo:\n", "# > kubricchallenges/movi/movi_c_worker.py\n", "def get_linear_camera_motion_start_end(\n", " movement_speed: float,\n", - " inner_radius: float = 8.,\n", - " outer_radius: float = 12.,\n", + " inner_radius: float = 8.0,\n", + " outer_radius: float = 12.0,\n", " z_offset: float = 0.1,\n", "):\n", - " \"\"\"Sample a linear path which starts and ends within a half-sphere shell.\"\"\"\n", - " while True:\n", - " camera_start = np.array(kb.sample_point_in_half_sphere_shell(inner_radius,\n", - " outer_radius,\n", - " z_offset))\n", - " direction = rng.rand(3) - 0.5\n", - " movement = direction / np.linalg.norm(direction) * movement_speed\n", - " camera_end = camera_start + movement\n", - " if (inner_radius <= np.linalg.norm(camera_end) <= outer_radius and\n", - " camera_end[2] > z_offset):\n", - " return camera_start, camera_end\n", - "\n" + " \"\"\"Sample a linear path which starts and ends within a half-sphere shell.\"\"\"\n", + " while True:\n", + " camera_start = np.array(\n", + " kb.sample_point_in_half_sphere_shell(inner_radius, outer_radius, z_offset)\n", + " )\n", + " direction = rng.rand(3) - 0.5\n", + " movement = direction / np.linalg.norm(direction) * movement_speed\n", + " camera_end = camera_start + movement\n", + " if (\n", + " inner_radius <= np.linalg.norm(camera_end) <= outer_radius\n", + " and camera_end[2] > z_offset\n", + " ):\n", + " return camera_start, camera_end" ] }, { @@ -87,7 +88,7 @@ "# [ 0 -f/H -0.5 ]\n", "# [ 0 0 -1.0 ]]\n", "# ```\n", - "# where W,H are the sensor width and height and f is the focal length. 
\n", + "# where W,H are the sensor width and height and f is the focal length.\n", "# Image coordinates (u,v) and Camera coordinates (x,y,z) are related as follows:\n", "# ```\n", "# -z * (u,v,1)^T = K * (x,y,z)^T\n", @@ -95,12 +96,14 @@ "# ...\n", "# > https://github.com/google-research/kubric/blob/main/kubric/core/cameras.py#L23\n", "def z_to_xyz(z, img_shape, K):\n", - " h, w, _ = z.shape[-3:] \n", - " uv = np.stack(np.meshgrid(np.linspace(1,0, w), np.linspace(0,1, h), indexing=\"ij\"), axis=0)\n", - " uv1 = np.concatenate([uv, np.ones((1,w,h))], axis=0)\n", - " z = data_stack[\"depth\"][t].T\n", - " xyz = np.linalg.inv(K)@(-z*uv1).reshape(3,-1)\n", - " return xyz.reshape(3,w,h).T" + " h, w, _ = z.shape[-3:]\n", + " uv = np.stack(\n", + " np.meshgrid(np.linspace(1, 0, w), np.linspace(0, 1, h), indexing=\"ij\"), axis=0\n", + " )\n", + " uv1 = np.concatenate([uv, np.ones((1, w, h))], axis=0)\n", + " z = data_stack[\"depth\"][t].T\n", + " xyz = np.linalg.inv(K) @ (-z * uv1).reshape(3, -1)\n", + " return xyz.reshape(3, w, h).T" ] } ], diff --git a/scripts/_mkl/notebooks/kubric/01 - Create a Kubric Scene.ipynb b/scripts/_mkl/notebooks/kubric/01 - Create a Kubric Scene.ipynb index 051d74bb..b4c99626 100644 --- a/scripts/_mkl/notebooks/kubric/01 - Create a Kubric Scene.ipynb +++ b/scripts/_mkl/notebooks/kubric/01 - Create a Kubric Scene.ipynb @@ -55,7 +55,7 @@ "from kubric.renderer import Blender\n", "import numpy as np\n", "from kubric import ArgumentParser\n", - "from argparse import Namespace\n" + "from argparse import Namespace" ] }, { @@ -66,31 +66,31 @@ "source": [ "# --- Some configuration values\n", "# the region in which to place objects [(min), (max)]\n", - "SPAWN_REGION = [(-5, -5, 1), (5, 5, 5)]\n", - "VELOCITY_RANGE = [(-4., -4., 0.), (4., 4., 0.)]\n", + "SPAWN_REGION = [(-5, -5, 1), (5, 5, 5)]\n", + "VELOCITY_RANGE = [(-4.0, -4.0, 0.0), (4.0, 4.0, 0.0)]\n", "\n", "FLAGS = Namespace(\n", - " resolution = 256, \n", - " frame_rate = 12, \n", - " step_rate = 240, \n", - " frame_start = 0, \n", - " frame_end = 30, #24 \n", - " logging_level = 'INFO', \n", - " seed = None, \n", - " scratch_dir = '/tmp/tmp82v1zpil', \n", - " job_dir = 'output', \n", - " objects_split = 'train', \n", - " min_num_objects = 3, \n", - " max_num_objects = 10, \n", - " floor_friction = 0.3, \n", - " floor_restitution = 0.5, \n", - " backgrounds_split = 'train', \n", - " camera = 'fixed_random', \n", - " # camera = 'linear_movement', \n", - " max_camera_movement = 4.0, \n", - " kubasic_assets = 'gs://kubric-public/assets/KuBasic/KuBasic.json', \n", - " hdri_assets = 'gs://kubric-public/assets/HDRI_haven/HDRI_haven.json', \n", - " gso_assets = 'gs://kubric-public/assets/GSO/GSO.json', \n", + " resolution=256,\n", + " frame_rate=12,\n", + " step_rate=240,\n", + " frame_start=0,\n", + " frame_end=30, # 24\n", + " logging_level=\"INFO\",\n", + " seed=None,\n", + " scratch_dir=\"/tmp/tmp82v1zpil\",\n", + " job_dir=\"output\",\n", + " objects_split=\"train\",\n", + " min_num_objects=3,\n", + " max_num_objects=10,\n", + " floor_friction=0.3,\n", + " floor_restitution=0.5,\n", + " backgrounds_split=\"train\",\n", + " camera=\"fixed_random\",\n", + " # camera = 'linear_movement',\n", + " max_camera_movement=4.0,\n", + " kubasic_assets=\"gs://kubric-public/assets/KuBasic/KuBasic.json\",\n", + " hdri_assets=\"gs://kubric-public/assets/HDRI_haven/HDRI_haven.json\",\n", + " gso_assets=\"gs://kubric-public/assets/GSO/GSO.json\",\n", " save_state=False,\n", ")" ] @@ -149,10 +149,10 @@ "scene, rng, output_dir, scratch_dir = 
kb.setup(FLAGS)\n", "\n", "simulator = PyBullet(scene, scratch_dir)\n", - "renderer = Blender(scene, scratch_dir, samples_per_pixel=64)\n", + "renderer = Blender(scene, scratch_dir, samples_per_pixel=64)\n", "\n", - "kubasic = kb.AssetSource.from_manifest(FLAGS.kubasic_assets)\n", - "gso = kb.AssetSource.from_manifest(FLAGS.gso_assets)\n", + "kubasic = kb.AssetSource.from_manifest(FLAGS.kubasic_assets)\n", + "gso = kb.AssetSource.from_manifest(FLAGS.gso_assets)\n", "hdri_source = kb.AssetSource.from_manifest(FLAGS.hdri_assets)" ] }, @@ -164,22 +164,23 @@ "source": [ "def get_linear_camera_motion_start_end(\n", " movement_speed: float,\n", - " inner_radius: float = 8.,\n", - " outer_radius: float = 12.,\n", + " inner_radius: float = 8.0,\n", + " outer_radius: float = 12.0,\n", " z_offset: float = 0.1,\n", "):\n", - " \"\"\"Sample a linear path which starts and ends within a half-sphere shell.\"\"\"\n", - " while True:\n", - " camera_start = np.array(kb.sample_point_in_half_sphere_shell(inner_radius,\n", - " outer_radius,\n", - " z_offset))\n", - " direction = rng.rand(3) - 0.5\n", - " movement = direction / np.linalg.norm(direction) * movement_speed\n", - " camera_end = camera_start + movement\n", - " if (inner_radius <= np.linalg.norm(camera_end) <= outer_radius and\n", - " camera_end[2] > z_offset):\n", - " return camera_start, camera_end\n", - "\n" + " \"\"\"Sample a linear path which starts and ends within a half-sphere shell.\"\"\"\n", + " while True:\n", + " camera_start = np.array(\n", + " kb.sample_point_in_half_sphere_shell(inner_radius, outer_radius, z_offset)\n", + " )\n", + " direction = rng.rand(3) - 0.5\n", + " movement = direction / np.linalg.norm(direction) * movement_speed\n", + " camera_end = camera_start + movement\n", + " if (\n", + " inner_radius <= np.linalg.norm(camera_end) <= outer_radius\n", + " and camera_end[2] > z_offset\n", + " ):\n", + " return camera_start, camera_end" ] }, { @@ -218,7 +219,6 @@ "# hdri_id = rng.choice(test_backgrounds)\n", "\n", "\n", - "\n", "background_hdri = hdri_source.create(asset_id=hdri_id)\n", "assert isinstance(background_hdri, kb.Texture)\n", "logging.info(\"Using background %s\", hdri_id)\n", @@ -226,10 +226,14 @@ "renderer._set_ambient_light_hdri(background_hdri.filename)\n", "\n", "# Dome\n", - "dome = kubasic.create(asset_id=\"dome\", name=\"dome\",\n", - " friction=FLAGS.floor_friction,\n", - " restitution=FLAGS.floor_restitution,\n", - " static=True, background=True)\n", + "dome = kubasic.create(\n", + " asset_id=\"dome\",\n", + " name=\"dome\",\n", + " friction=FLAGS.floor_friction,\n", + " restitution=FLAGS.floor_restitution,\n", + " static=True,\n", + " background=True,\n", + ")\n", "assert isinstance(dome, kb.FileBasedObject)\n", "\n", "scene += dome\n", @@ -306,18 +310,24 @@ } ], "source": [ - "ss = np.array([kb.sample_point_in_half_sphere_shell(\n", - " inner_radius=8., outer_radius=9., offset=0.1) for t in range(1_000)])\n", + "ss = np.array(\n", + " [\n", + " kb.sample_point_in_half_sphere_shell(\n", + " inner_radius=8.0, outer_radius=9.0, offset=0.1\n", + " )\n", + " for t in range(1_000)\n", + " ]\n", + ")\n", "\n", "\n", - "order = np.argsort(ss[:,2])\n", + "order = np.argsort(ss[:, 2])\n", "ss = ss[order]\n", "\n", - "fig, axs = plt.subplots(1,2, figsize=(4,2))\n", + "fig, axs = plt.subplots(1, 2, figsize=(4, 2))\n", "axs[0].set_aspect(1)\n", "axs[1].set_aspect(1)\n", - "axs[0].scatter(*ss[:,[0,1]].T, c=ss[:,2])\n", - "axs[1].scatter(*ss[np.abs(ss)[:,1]<1.][:,[0,2]].T)" + "axs[0].scatter(*ss[:, [0, 1]].T, c=ss[:, 
2])\n", + "axs[1].scatter(*ss[np.abs(ss)[:, 1] < 1.0][:, [0, 2]].T)" ] }, { @@ -326,7 +336,7 @@ "metadata": {}, "outputs": [], "source": [ - "camera_position = np.array([0,-8, 1])" + "camera_position = np.array([0, -8, 1])" ] }, { @@ -346,7 +356,7 @@ "# Camera\n", "logging.info(\"Setting up the Camera...\")\n", "\n", - "scene.camera = kb.PerspectiveCamera(focal_length=35., sensor_width=32)\n", + "scene.camera = kb.PerspectiveCamera(focal_length=35.0, sensor_width=32)\n", "\n", "\n", "if FLAGS.camera == \"fixed_random\":\n", @@ -358,7 +368,7 @@ "elif FLAGS.camera == \"linear_movement\":\n", " # RDM CHOICE\n", " camera_start, camera_end = get_linear_camera_motion_start_end(\n", - " movement_speed=rng.uniform(low=0., high=FLAGS.max_camera_movement)\n", + " movement_speed=rng.uniform(low=0.0, high=FLAGS.max_camera_movement)\n", " )\n", "\n", " # linearly interpolate the camera position between these two points\n", @@ -366,13 +376,15 @@ " # we start one frame early and end one frame late to ensure that\n", " # forward and backward flow are still consistent for the last and first frames\n", " for frame in range(FLAGS.frame_start - 1, FLAGS.frame_end + 2):\n", - " interp = ((frame - FLAGS.frame_start + 1) /\n", - " (FLAGS.frame_end - FLAGS.frame_start + 3))\n", - " scene.camera.position = (interp * np.array(camera_start) +\n", - " (1 - interp) * np.array(camera_end))\n", + " interp = (frame - FLAGS.frame_start + 1) / (\n", + " FLAGS.frame_end - FLAGS.frame_start + 3\n", + " )\n", + " scene.camera.position = interp * np.array(camera_start) + (\n", + " 1 - interp\n", + " ) * np.array(camera_end)\n", " scene.camera.look_at((0, 0, 0))\n", " scene.camera.keyframe_insert(\"position\", frame)\n", - " scene.camera.keyframe_insert(\"quaternion\", frame)\n" + " scene.camera.keyframe_insert(\"quaternion\", frame)" ] }, { @@ -392,7 +404,7 @@ } ], "source": [ - "35/32" + "35 / 32" ] }, { @@ -435,15 +447,21 @@ "metadata": {}, "outputs": [], "source": [ - "obj_ids = ['Reebok_ZIGSTORM', 'Canon_225226_Ink_Cartridges_BlackColor_Cyan_Magenta_Yellow_6_count', 'Canon_225226_Ink_Cartridges_BlackColor_Cyan_Magenta_Yellow_6_count']\n", + "obj_ids = [\n", + " \"Reebok_ZIGSTORM\",\n", + " \"Canon_225226_Ink_Cartridges_BlackColor_Cyan_Magenta_Yellow_6_count\",\n", + " \"Canon_225226_Ink_Cartridges_BlackColor_Cyan_Magenta_Yellow_6_count\",\n", + "]\n", "num_objects = len(obj_ids)\n", "scales = np.ones(num_objects)\n", - "positions = np.array([\n", - " [3.1545544, 4.309582, 2.6392992],\n", - " [-3.9297907, 3.285696, 3.9528222],\n", - " [-3.3671062, 1.4849155, 2.3747466]\n", - "])\n", - "velocities = np.zeros((num_objects,3)) - positions" + "positions = np.array(\n", + " [\n", + " [3.1545544, 4.309582, 2.6392992],\n", + " [-3.9297907, 3.285696, 3.9528222],\n", + " [-3.3671062, 1.4849155, 2.3747466],\n", + " ]\n", + ")\n", + "velocities = np.zeros((num_objects, 3)) - positions" ] }, { @@ -485,12 +503,10 @@ "source": [ "# RDM CHOICE\n", "# num_objects = rng.randint(FLAGS.min_num_objects,\n", - " # FLAGS.max_num_objects+1)\n", + "# FLAGS.max_num_objects+1)\n", "logging.info(\"Randomly placing %d objects:\", num_objects)\n", "\n", "\n", - "\n", - "\n", "# Add random objects\n", "train_split, test_split = gso.get_test_split(fraction=0.1)\n", "if FLAGS.objects_split == \"train\":\n", @@ -502,28 +518,26 @@ "\n", "\n", "for i in range(num_objects):\n", - " obj = gso.create(asset_id = obj_ids[i])\n", - " \n", + " obj = gso.create(asset_id=obj_ids[i])\n", + "\n", " assert isinstance(obj, kb.FileBasedObject)\n", "\n", - " # RDM 
CHOICE\n", - "# scale = rng.uniform(0.75, 3.0)\n", + " # RDM CHOICE\n", + " # scale = rng.uniform(0.75, 3.0)\n", " scale = scales[i]\n", - " \n", + "\n", " obj.scale = scale / np.max(obj.bounds[1] - obj.bounds[0])\n", " obj.metadata[\"scale\"] = scale\n", " scene += obj\n", "\n", - " \n", " # RDM CHOICE???\n", " kb.move_until_no_overlap(obj, simulator, spawn_region=SPAWN_REGION, rng=rng)\n", " # initialize velocity randomly but biased towards center\n", " # obj.velocity = (rng.uniform(*VELOCITY_RANGE) -\n", - " # [obj.position[0], obj.position[1], 0])\n", + " # [obj.position[0], obj.position[1], 0])\n", " obj.velocity = velocities[i]\n", " obj.position = positions[i]\n", - " logging.info(\" Added %s at %s\", obj.asset_id, obj.position)\n", - "\n" + " logging.info(\" Added %s at %s\", obj.asset_id, obj.position)" ] }, { @@ -563,14 +577,15 @@ ], "source": [ "if FLAGS.save_state:\n", - " logging.info(\"Saving the simulator state to '%s' prior to the simulation.\",\n", - " output_dir / \"scene.bullet\")\n", - " simulator.save_state(output_dir / \"scene.bullet\")\n", + " logging.info(\n", + " \"Saving the simulator state to '%s' prior to the simulation.\",\n", + " output_dir / \"scene.bullet\",\n", + " )\n", + " simulator.save_state(output_dir / \"scene.bullet\")\n", "\n", "# Run dynamic objects simulation\n", "logging.info(\"Running the simulation ...\")\n", - "animation, collisions = simulator.run(frame_start=0,\n", - " frame_end=scene.frame_end+1)" + "animation, collisions = simulator.run(frame_start=0, frame_end=scene.frame_end + 1)" ] }, { @@ -5631,9 +5646,8 @@ "source": [ "# --- Rendering\n", "if FLAGS.save_state:\n", - " logging.info(\"Saving the renderer state to '%s' \",\n", - " output_dir / \"scene.blend\")\n", - " renderer.save_state(output_dir / \"scene.blend\")\n", + " logging.info(\"Saving the renderer state to '%s' \", output_dir / \"scene.blend\")\n", + " renderer.save_state(output_dir / \"scene.blend\")\n", "\n", "\n", "logging.info(\"Rendering the scene ...\")\n", @@ -5641,23 +5655,25 @@ "\n", "# --- Postprocessing\n", "kb.compute_visibility(data_stack[\"segmentation\"], scene.assets)\n", - "visible_foreground_assets = [asset for asset in scene.foreground_assets\n", - " if np.max(asset.metadata[\"visibility\"]) > 0]\n", + "visible_foreground_assets = [\n", + " asset\n", + " for asset in scene.foreground_assets\n", + " if np.max(asset.metadata[\"visibility\"]) > 0\n", + "]\n", "visible_foreground_assets = sorted( # sort assets by their visibility\n", " visible_foreground_assets,\n", " key=lambda asset: np.sum(asset.metadata[\"visibility\"]),\n", - " reverse=True)\n", + " reverse=True,\n", + ")\n", "\n", "data_stack[\"segmentation\"] = kb.adjust_segmentation_idxs(\n", - " data_stack[\"segmentation\"],\n", - " scene.assets,\n", - " visible_foreground_assets)\n", + " data_stack[\"segmentation\"], scene.assets, visible_foreground_assets\n", + ")\n", "scene.metadata[\"num_instances\"] = len(visible_foreground_assets)\n", "\n", "# Save to image files\n", "kb.write_image_dict(data_stack, output_dir)\n", - "kb.post_processing.compute_bboxes(data_stack[\"segmentation\"],\n", - " visible_foreground_assets)\n", + "kb.post_processing.compute_bboxes(data_stack[\"segmentation\"], visible_foreground_assets)\n", "\n", "# --- Metadata\n", "logging.info(\"Collecting and storing metadata for each object.\")\n", @@ -5668,12 +5684,16 @@ " \"instances\": kb.get_instance_info(scene, visible_foreground_assets),\n", "}\n", "kb.write_json(filename=output_dir / \"metadata.json\", data=metadata)\n", - 
"kb.write_json(filename=output_dir / \"events.json\", data={\n", - " \"collisions\": kb.process_collisions(\n", - " collisions, scene, assets_subset=visible_foreground_assets),\n", - "})\n", + "kb.write_json(\n", + " filename=output_dir / \"events.json\",\n", + " data={\n", + " \"collisions\": kb.process_collisions(\n", + " collisions, scene, assets_subset=visible_foreground_assets\n", + " ),\n", + " },\n", + ")\n", "\n", - "kb.done()\n" + "kb.done()" ] }, { @@ -5701,12 +5721,12 @@ "n = FLAGS.frame_end - FLAGS.frame_start\n", "\n", "image_files = []\n", - "path = Path(\"output\") \n", - "fname_gif = path/\"output.gif\"\n", + "path = Path(\"output\")\n", + "fname_gif = path / \"output.gif\"\n", "for i in range(n):\n", " id = f\"{i:0>5.0f}\"\n", - " fname = path/f\"rgba_{id}.png\" \n", - " # fname = path/f\"depth_{id}.tiff\" \n", + " fname = path / f\"rgba_{id}.png\"\n", + " # fname = path/f\"depth_{id}.tiff\"\n", " image_files.append(fname)\n", "\n", "images = [Image.open(image) for image in image_files]\n", @@ -5772,9 +5792,9 @@ "print(data_stack.keys())\n", "\n", "(\n", - "data_stack[\"depth\"][0].shape,\n", - "data_stack[\"normal\"][0].shape,\n", - "metadata[\"camera\"],\n", + " data_stack[\"depth\"][0].shape,\n", + " data_stack[\"normal\"][0].shape,\n", + " metadata[\"camera\"],\n", ")" ] }, @@ -5808,7 +5828,7 @@ "import matplotlib.pyplot as plt\n", "\n", "t = 0\n", - "plt.imshow(data_stack[\"depth\"][t][...,0], vmin=7, vmax=12.)\n", + "plt.imshow(data_stack[\"depth\"][t][..., 0], vmin=7, vmax=12.0)\n", "# plt.imshow(data_stack[\"depth\"][0][...,0])\n", "plt.colorbar()" ] @@ -5870,7 +5890,7 @@ } ], "source": [ - "metadata['camera']['focal_length']" + "metadata[\"camera\"][\"focal_length\"]" ] }, { @@ -5890,7 +5910,7 @@ } ], "source": [ - "256/32" + "256 / 32" ] }, { @@ -5913,7 +5933,7 @@ ], "source": [ "K = metadata[\"camera\"][\"K\"]\n", - "K\n" + "K" ] }, { @@ -5941,7 +5961,7 @@ } ], "source": [ - "K, np.linalg.inv(K),np.linalg.inv(K.T)==np.linalg.inv(K).T" + "K, np.linalg.inv(K), np.linalg.inv(K.T) == np.linalg.inv(K).T" ] }, { @@ -5952,7 +5972,7 @@ "source": [ "def unproject_depth(depth, intrinsics):\n", " \"\"\"Unprojects a depth image into a point cloud.\n", - " \n", + "\n", " Args:\n", " depth (jnp.ndarray): The depth image. Shape (H, W)\n", " intrinsics (b.camera.Intrinsics): The camera intrinsics.\n", @@ -5994,7 +6014,7 @@ } ], "source": [ - "np.ones((2,3,4)).T.shape" + "np.ones((2, 3, 4)).T.shape" ] }, { @@ -6057,6 +6077,7 @@ "t = 0\n", "n = 256\n", "\n", + "\n", "# Kubric works with \"normalized\" image coordinates $u,v \\in [0,1]$\n", "# The intrinsic matrix is given by\n", "# ```\n", @@ -6064,7 +6085,7 @@ "# [ 0 -f/H -0.5 ]\n", "# [ 0 0 -1.0 ]]\n", "# ```\n", - "# where W,H are the sensor width and height and f is the focal length. 
\n", + "# where W,H are the sensor width and height and f is the focal length.\n", "# Image coordinates (u,v) and Camera coordinates (x,y,z) are related as follows:\n", "# ```\n", "# -z * (u,v,1)^T = K * (x,y,z)^T\n", @@ -6072,23 +6093,25 @@ "# ...\n", "# > https://github.com/google-research/kubric/blob/main/kubric/core/cameras.py#L23\n", "def z_to_xyz(z, img_shape, K):\n", - " h, w, _ = z.shape[-3:] \n", - " uv = np.stack(np.meshgrid(np.linspace(1,0, w), np.linspace(0,1, h), indexing=\"ij\"), axis=0)\n", - " uv1 = np.concatenate([uv, np.ones((1,w,h))], axis=0)\n", - " z = data_stack[\"depth\"][t].T\n", - " xyz = np.linalg.inv(K)@(-z*uv1).reshape(3,-1)\n", - " return xyz.reshape(3,w,h).T\n", + " h, w, _ = z.shape[-3:]\n", + " uv = np.stack(\n", + " np.meshgrid(np.linspace(1, 0, w), np.linspace(0, 1, h), indexing=\"ij\"), axis=0\n", + " )\n", + " uv1 = np.concatenate([uv, np.ones((1, w, h))], axis=0)\n", + " z = data_stack[\"depth\"][t].T\n", + " xyz = np.linalg.inv(K) @ (-z * uv1).reshape(3, -1)\n", + " return xyz.reshape(3, w, h).T\n", + "\n", "\n", "# uv = 1.*np.stack(np.meshgrid(np.linspace(0,1, n), np.linspace(0, 1, n)), axis=-1).reshape(-1,2)\n", "# uv1 = np.concatenate([uv, np.ones((n**2,1))], axis=-1)\n", "# z = data_stack[\"depth\"][t].reshape(-1,1)\n", "# xyz = (-z*uv1)@(np.linalg.inv(K).T)\n", - "# xyz = xyz.reshape(n,n,3) \n", - "\n", + "# xyz = xyz.reshape(n,n,3)\n", "\n", "\n", - "im = z_to_xyz(data_stack[\"depth\"][t], (n,n), K)\n", - "plt.matshow(im[...,2])\n", + "im = z_to_xyz(data_stack[\"depth\"][t], (n, n), K)\n", + "plt.matshow(im[..., 2])\n", "plt.matshow(data_stack[\"depth\"][t])\n", "plt.colorbar()" ] @@ -6111,13 +6134,13 @@ ], "source": [ "plt.gca().set_aspect(1)\n", - "plt.ylim(6,15)\n", - "plt.xlim(-8,8)\n", - "for i in np.arange(0,150, step=1):\n", - " plt.scatter(im[i,:,1], im[i,:,2], s= 1, alpha=0.2, c=\"r\")\n", + "plt.ylim(6, 15)\n", + "plt.xlim(-8, 8)\n", + "for i in np.arange(0, 150, step=1):\n", + " plt.scatter(im[i, :, 1], im[i, :, 2], s=1, alpha=0.2, c=\"r\")\n", "\n", - "for i in np.arange(150,256, step=1):\n", - " plt.scatter(im[i,:,1], im[i,:,2], s= 1, alpha=0.2, c=\"b\")" + "for i in np.arange(150, 256, step=1):\n", + " plt.scatter(im[i, :, 1], im[i, :, 2], s=1, alpha=0.2, c=\"b\")" ] }, { diff --git a/scripts/experiments/collaborations/aryan.ipynb b/scripts/experiments/collaborations/aryan.ipynb index e318da4c..deeb45c5 100644 --- a/scripts/experiments/collaborations/aryan.ipynb +++ b/scripts/experiments/collaborations/aryan.ipynb @@ -53,26 +53,25 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "paths = glob.glob(\n", - " \"*.pkl\"\n", - ")\n", + "paths = glob.glob(\"*.pkl\")\n", "all_data = pickle.load(open(paths[0], \"rb\"))\n", "IDX = 1\n", "data = all_data[IDX]\n", "\n", "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + "camera_pose = data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", "far = 5.0\n", "depth[depth < near] = far\n", - 
"rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,near,far))\n", + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, near, far)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -102,11 +101,20 @@ "outputs": [], "source": [ "table_pose, table_dims = b.utils.infer_table_plane(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics), jnp.eye(4), rgbd_scaled_down.intrinsics,\n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics),\n", + " jnp.eye(4),\n", + " rgbd_scaled_down.intrinsics,\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=0.1,\n", ")\n", "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", table_pose)\n", "table_mesh = b.utils.make_cuboid_mesh(table_dims)\n", "b.show_trimesh(\"table_mesh\", table_mesh)\n", @@ -120,7 +128,10 @@ "outputs": [], "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),\"sample_objs/sphere.obj\"), scaling_factor=1.0/30.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/sphere.obj\"),\n", + " scaling_factor=1.0 / 30.0,\n", + ")\n", "b.show_trimesh(\"sphere_mesh\", b.RENDERER.meshes[0], color=(0.0, 1.0, 0.0))" ] }, @@ -131,9 +142,16 @@ "outputs": [], "source": [ "def _contact_parameters_to_pose(cp, table_pose):\n", - " return table_pose @ b.scene_graph.relative_pose_from_edge(cp, 2, b.RENDERER.model_box_dims[0])\n", + " return table_pose @ b.scene_graph.relative_pose_from_edge(\n", + " cp, 2, b.RENDERER.model_box_dims[0]\n", + " )\n", + "\n", + "\n", "contact_parameters_to_pose = jax.jit(_contact_parameters_to_pose)\n", - "contact_parameters_to_pose_vmap = jax.jit(jax.vmap(_contact_parameters_to_pose, in_axes=(0,None)))\n", + "contact_parameters_to_pose_vmap = jax.jit(\n", + " jax.vmap(_contact_parameters_to_pose, in_axes=(0, None))\n", + ")\n", + "\n", "\n", "def _compute_likelihood(rendered_depth):\n", " return b.threedp3_likelihood_old(\n", @@ -143,8 +161,10 @@ " 0.00001,\n", " 1.0,\n", " 1.0,\n", - " 3\n", + " 3,\n", " )\n", + "\n", + "\n", "compute_likelihood = jax.jit(_compute_likelihood)\n", "compute_likelihood_vmap = jax.jit(jax.vmap(_compute_likelihood))" ] @@ -156,20 +176,32 @@ "outputs": [], "source": [ "contact_parameter_grid = b.utils.make_translation_grid_enumeration_3d(\n", - " -table_dims[0]/2.0, -table_dims[1]/2.0, 0.0,\n", - " table_dims[0]/2.0, table_dims[1]/2.0, 0.0,\n", - " 50, 50, 1\n", + " -table_dims[0] / 2.0,\n", + " -table_dims[1] / 2.0,\n", + " 0.0,\n", + " table_dims[0] / 2.0,\n", + " table_dims[1] / 2.0,\n", + " 0.0,\n", + " 50,\n", + " 50,\n", + " 1,\n", ")\n", "orange_poses_full = contact_parameters_to_pose_vmap(contact_parameter_grid, table_pose)\n", "\n", "# for (i,p) in enumerate(orange_poses_full):\n", "# b.show_pose(f\"{i}\", p)\n", "\n", - "rendered_depth_orange_alone_all = b.RENDERER.render_many(orange_poses_full[:,None,...], jnp.array([0]))[...,2]\n", - "poses_in_field_of_view = (rendered_depth_orange_alone_all < b.RENDERER.intrinsics.far).any(-1).any(-1)\n", + 
"rendered_depth_orange_alone_all = b.RENDERER.render_many(\n", + " orange_poses_full[:, None, ...], jnp.array([0])\n", + ")[..., 2]\n", + "poses_in_field_of_view = (\n", + " (rendered_depth_orange_alone_all < b.RENDERER.intrinsics.far).any(-1).any(-1)\n", + ")\n", "orange_poses = orange_poses_full[poses_in_field_of_view]\n", "rendered_depth_orange_alone = rendered_depth_orange_alone_all[poses_in_field_of_view]\n", - "rendered_depth_spliced = jnp.minimum(rendered_depth_orange_alone, rgbd_scaled_down.depth[None, :,:])\n", + "rendered_depth_spliced = jnp.minimum(\n", + " rendered_depth_orange_alone, rgbd_scaled_down.depth[None, :, :]\n", + ")\n", "scores = compute_likelihood_vmap(rendered_depth_spliced)" ] }, @@ -190,10 +222,12 @@ "source": [ "IDX = 100\n", "b.show_pose(\"candidate_pose\", orange_poses[IDX])\n", - "b.hstack_images([\n", - " b.get_depth_image(rendered_depth_orange_alone[IDX]),\n", - " b.get_depth_image(rendered_depth_spliced[IDX])\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(rendered_depth_orange_alone[IDX]),\n", + " b.get_depth_image(rendered_depth_spliced[IDX]),\n", + " ]\n", + ")" ] }, { @@ -203,7 +237,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", table_pose)\n", "b.show_trimesh(\"sphere_mesh\", b.RENDERER.meshes[0], color=(0.0, 1.0, 0.0))" ] @@ -223,15 +262,17 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]\n", + "key = jax.random.split(key, 2)[0]\n", "sampled_indices = jax.random.categorical(key, scores, shape=(2500,))\n", "sampled_poses = orange_poses[sampled_indices]\n", "idx = sampled_indices[0]\n", "b.set_pose(\"sphere_mesh\", orange_poses[idx])\n", - "b.hstack_images([\n", - " b.get_depth_image(rendered_depth_orange_alone[idx]),\n", - " b.get_depth_image(rendered_depth_spliced[idx])\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(rendered_depth_orange_alone[idx]),\n", + " b.get_depth_image(rendered_depth_spliced[idx]),\n", + " ]\n", + ")" ] }, { @@ -240,7 +281,9 @@ "metadata": {}, "outputs": [], "source": [ - "overlay_img = b.RENDERER.render(sampled_poses, jnp.full(sampled_poses.shape[0], 0))[...,2]\n", + "overlay_img = b.RENDERER.render(sampled_poses, jnp.full(sampled_poses.shape[0], 0))[\n", + " ..., 2\n", + "]\n", "b.get_depth_image(overlay_img)" ] } diff --git a/scripts/experiments/collaborations/ben.ipynb b/scripts/experiments/collaborations/ben.ipynb index fd95d4ee..d5002d19 100644 --- a/scripts/experiments/collaborations/ben.ipynb +++ b/scripts/experiments/collaborations/ben.ipynb @@ -64,7 +64,7 @@ ], "source": [ "data = np.load(\"test_scenes/1.npz\")\n", - "rgb = jnp.array(data[\"rgb\"])\n", + "rgb = jnp.array(data[\"rgb\"])\n", "gt_ids = jnp.array(data[\"gt_ids\"]).astype(jnp.int32)\n", "gt_poses = jnp.array(data[\"gt_poses\"])\n", "gt_poses_offset = jnp.array(data[\"gt_poses_offset\"])\n", @@ -108,12 +108,17 @@ ], "source": [ "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " 
b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=jnp.array([2.0, 2.0, 0.02]))" + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=jnp.array([2.0, 2.0, 0.02]),\n", + ")" ] }, { @@ -122,12 +127,14 @@ "metadata": {}, "outputs": [], "source": [ - "poses = jnp.concatenate([pred_poses_cam_frame, b.transform_from_pos(jnp.array([0.0, 0.0, -0.1]))[None,...] ])\n", + "poses = jnp.concatenate(\n", + " [pred_poses_cam_frame, b.transform_from_pos(jnp.array([0.0, 0.0, -0.1]))[None, ...]]\n", + ")\n", "ids = jnp.concatenate([pred_ids, jnp.array([21])])\n", "b.clear()\n", "colors = b.distinct_colors(len(poses))\n", "for i in range(len(poses)):\n", - " b.show_trimesh(f\"{i}\", b.RENDERER.meshes[ids[i]],color=colors[i])\n", + " b.show_trimesh(f\"{i}\", b.RENDERER.meshes[ids[i]], color=colors[i])\n", " b.set_pose(f\"{i}\", poses[i])" ] }, @@ -150,7 +157,7 @@ ], "source": [ "img = b.RENDERER.render(b.inverse_pose(cam_pose_cv2) @ poses, ids)\n", - "depth = b.get_depth_image(img[...,2])\n", + "depth = b.get_depth_image(img[..., 2])\n", "depth" ] }, @@ -176,18 +183,34 @@ } ], "source": [ - "def get_approximating_contact(object_idx_1, object_idx_2, face_1,face_2):\n", - " contact_params ,slack = b.scene_graph.closest_approximate_contact_params(\n", - " poses[object_idx_1] @ b.scene_graph.get_contact_planes(b.RENDERER.model_box_dims[ids[object_idx_1]])[face_1],\n", - " poses[object_idx_2] @ b.scene_graph.get_contact_planes(b.RENDERER.model_box_dims[ids[object_idx_2]])[face_2]\n", + "def get_approximating_contact(object_idx_1, object_idx_2, face_1, face_2):\n", + " contact_params, slack = b.scene_graph.closest_approximate_contact_params(\n", + " poses[object_idx_1]\n", + " @ b.scene_graph.get_contact_planes(\n", + " b.RENDERER.model_box_dims[ids[object_idx_1]]\n", + " )[face_1],\n", + " poses[object_idx_2]\n", + " @ b.scene_graph.get_contact_planes(\n", + " b.RENDERER.model_box_dims[ids[object_idx_2]]\n", + " )[face_2],\n", " )\n", - " dimensions_on_contact_plane = b.scene_graph.get_contact_plane_dimenions(b.RENDERER.model_box_dims[ids[object_idx_1]])[face_1]\n", - " valid = (object_idx_1!=object_idx_2) * jnp.all(jnp.abs(contact_params[:2]) < (dimensions_on_contact_plane / 2.0))\n", - " return jax.lax.select(valid, contact_params , jnp.full((3,), jnp.inf)) , jax.lax.select(valid, slack, jnp.full((4,4), jnp.inf))\n", + " dimensions_on_contact_plane = b.scene_graph.get_contact_plane_dimenions(\n", + " b.RENDERER.model_box_dims[ids[object_idx_1]]\n", + " )[face_1]\n", + " valid = (object_idx_1 != object_idx_2) * jnp.all(\n", + " jnp.abs(contact_params[:2]) < (dimensions_on_contact_plane / 2.0)\n", + " )\n", + " return jax.lax.select(\n", + " valid, contact_params, jnp.full((3,), jnp.inf)\n", + " ), jax.lax.select(valid, slack, jnp.full((4, 4), jnp.inf))\n", + "\n", + "\n", "get_approximating_contact_vmap = jax.jit(b.utils.multivmap(get_approximating_contact))\n", "\n", "\n", - "inferred_contact_params, slacks = get_approximating_contact_vmap(jnp.arange(len(poses)), jnp.arange(len(poses)), jnp.arange(6), jnp.arange(6))\n", + "inferred_contact_params, slacks = 
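
`get_approximating_contact` above is evaluated for every (object, object, face, face) combination via `multivmap`, with invalid combinations (e.g., an object contacting itself) mapped to `inf` sentinels so that later argmin/argsort pushes them to the end. The same pattern on a toy pairwise cost, using nested `jax.vmap` in place of `multivmap` and `jnp.where` in place of `jax.lax.select`:

    import jax
    import jax.numpy as jnp

    def pair_cost(i, j):
        # Toy stand-in: self-pairs are invalid and get an inf sentinel.
        cost = jnp.abs(i - j).astype(jnp.float32)
        return jnp.where(i != j, cost, jnp.inf)

    n = 4
    idx = jnp.arange(n)
    costs = jax.vmap(lambda i: jax.vmap(lambda j: pair_cost(i, j))(idx))(idx)
    best_pair = jnp.unravel_index(jnp.argmin(costs), costs.shape)  # off-diagonal
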
get_approximating_contact_vmap(\n", + " jnp.arange(len(poses)), jnp.arange(len(poses)), jnp.arange(6), jnp.arange(6)\n", + ")\n", "print(inferred_contact_params.shape)\n", "print(slacks.shape)" ] @@ -198,13 +221,17 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "position_errors = jnp.abs(slacks[...,:3,3]).sum(-1)\n", + "position_errors = jnp.abs(slacks[..., :3, 3]).sum(-1)\n", "\n", "from jax.scipy.spatial.transform import Rotation\n", - "rotation_matrix_to_angle_error = lambda pose : jnp.linalg.norm(Rotation.from_matrix(pose[:3,:3]).as_rotvec())\n", "\n", - "angle_errors = jnp.vectorize(rotation_matrix_to_angle_error, signature=\"(3,3)->()\")(slacks[...,:3,:3])\n", + "rotation_matrix_to_angle_error = lambda pose: jnp.linalg.norm(\n", + " Rotation.from_matrix(pose[:3, :3]).as_rotvec()\n", + ")\n", + "\n", + "angle_errors = jnp.vectorize(rotation_matrix_to_angle_error, signature=\"(3,3)->()\")(\n", + " slacks[..., :3, :3]\n", + ")\n", "assert position_errors.shape == angle_errors.shape" ] }, @@ -220,10 +247,12 @@ "box_dims = b.RENDERER.model_box_dims[ids]\n", "parents = jnp.full((len(poses),), -1)\n", "node_names = NAMES[ids]\n", - "contact_params = jnp.full((len(poses),3), 0.0)\n", + "contact_params = jnp.full((len(poses), 3), 0.0)\n", "face_parents = jnp.full((len(poses),), -1)\n", "face_children = jnp.full((len(poses),), -1)\n", - "sg = b.scene_graph.SceneGraph(root_poses, box_dims, parents, contact_params, face_parents, face_children)\n", + "sg = b.scene_graph.SceneGraph(\n", + " root_poses, box_dims, parents, contact_params, face_parents, face_children\n", + ")\n", "sg.visualize(\"test.png\", node_names=node_names)" ] }, @@ -251,12 +280,20 @@ " (child_idx, face_1, face_2) = jnp.unravel_index(i, position_errors[TABLE_IDX].shape)\n", " if jnp.all(parents[child_idx] != -1):\n", " continue\n", - " print(position_errors[TABLE_IDX, child_idx, face_1, face_2], angle_errors[TABLE_IDX, child_idx, face_1, face_2])\n", - " if position_errors[TABLE_IDX, child_idx, face_1, face_2] < 0.01 and angle_errors[TABLE_IDX, child_idx, face_1, face_2] < 0.03:\n", + " print(\n", + " position_errors[TABLE_IDX, child_idx, face_1, face_2],\n", + " angle_errors[TABLE_IDX, child_idx, face_1, face_2],\n", + " )\n", + " if (\n", + " position_errors[TABLE_IDX, child_idx, face_1, face_2] < 0.01\n", + " and angle_errors[TABLE_IDX, child_idx, face_1, face_2] < 0.03\n", + " ):\n", " parents = parents.at[child_idx].set(TABLE_IDX)\n", " face_parents = face_parents.at[child_idx].set(face_1)\n", " face_children = face_children.at[child_idx].set(face_2)\n", - " contact_params = contact_params.at[child_idx].set(inferred_contact_params[TABLE_IDX, child_idx, face_1, face_2])\n", + " contact_params = contact_params.at[child_idx].set(\n", + " inferred_contact_params[TABLE_IDX, child_idx, face_1, face_2]\n", + " )\n", " else:\n", " break" ] @@ -279,15 +316,25 @@ "source": [ "sort_order = jnp.argsort(position_errors.reshape(-1))\n", "for i in sort_order:\n", - " (parent_idx, child_idx, face_1, face_2) = jnp.unravel_index(i, position_errors.shape)\n", + " (parent_idx, child_idx, face_1, face_2) = jnp.unravel_index(\n", + " i, position_errors.shape\n", + " )\n", " if jnp.all(parents[child_idx] != -1):\n", " continue\n", - " print(position_errors[parent_idx, child_idx, face_1, face_2], angle_errors[parent_idx, child_idx, face_1, face_2])\n", - " if position_errors[parent_idx, child_idx, face_1, face_2] < 0.01 and angle_errors[parent_idx, child_idx, face_1, face_2] < 0.09:\n", + " print(\n", + " position_errors[parent_idx, 
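
The error decomposition above splits each slack transform into a translation part (`slacks[..., :3, 3]`, scored with an L1 norm) and a rotation part (`slacks[..., :3, :3]`, scored as the geodesic angle, i.e. the norm of the rotation vector). `jnp.vectorize` with a gufunc signature broadcasts the scalar-angle function over arbitrary batch axes:

    import jax.numpy as jnp
    from jax.scipy.spatial.transform import Rotation

    angle_of = jnp.vectorize(
        lambda R: jnp.linalg.norm(Rotation.from_matrix(R).as_rotvec()),
        signature="(3,3)->()",
    )

    Rs = jnp.broadcast_to(jnp.eye(3), (5, 6, 3, 3))
    angles = angle_of(Rs)  # shape (5, 6); zero angle for identity rotations
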
child_idx, face_1, face_2],\n", + " angle_errors[parent_idx, child_idx, face_1, face_2],\n", + " )\n", + " if (\n", + " position_errors[parent_idx, child_idx, face_1, face_2] < 0.01\n", + " and angle_errors[parent_idx, child_idx, face_1, face_2] < 0.09\n", + " ):\n", " parents = parents.at[child_idx].set(parent_idx)\n", " face_parents = face_parents.at[child_idx].set(face_1)\n", " face_children = face_children.at[child_idx].set(face_2)\n", - " contact_params = contact_params.at[child_idx].set(inferred_contact_params[parent_idx, child_idx, face_1, face_2])\n", + " contact_params = contact_params.at[child_idx].set(\n", + " inferred_contact_params[parent_idx, child_idx, face_1, face_2]\n", + " )\n", " else:\n", " break" ] @@ -298,7 +345,9 @@ "metadata": {}, "outputs": [], "source": [ - "sg = b.scene_graph.SceneGraph(root_poses, box_dims, parents, contact_params, face_parents, face_children)\n", + "sg = b.scene_graph.SceneGraph(\n", + " root_poses, box_dims, parents, contact_params, face_parents, face_children\n", + ")\n", "sg.visualize(\"test.png\", node_names=node_names)" ] }, @@ -343,7 +392,7 @@ "source": [ "sg_poses = sg.get_poses()\n", "img = b.RENDERER.render(b.inverse_pose(cam_pose_cv2) @ sg_poses, ids)\n", - "depth = b.get_depth_image(img[...,2])\n", + "depth = b.get_depth_image(img[..., 2])\n", "depth" ] }, diff --git a/scripts/experiments/collaborations/single_object_model_mccoy.ipynb b/scripts/experiments/collaborations/single_object_model_mccoy.ipynb index 2e238060..ae39fd2a 100644 --- a/scripts/experiments/collaborations/single_object_model_mccoy.ipynb +++ b/scripts/experiments/collaborations/single_object_model_mccoy.ipynb @@ -14,6 +14,7 @@ "import jax\n", "import os\n", "from tqdm import tqdm\n", + "\n", "console = genjax.pretty(show_locals=False)\n", "from genjax._src.core.transforms.incremental import NoChange\n", "from genjax._src.core.transforms.incremental import UnknownChange\n", @@ -38,19 +39,17 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=500.0, fy=500.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=20.0\n", + " height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" ] @@ -76,14 +75,18 @@ "source": [ "@genjax.gen\n", "def single_object_model():\n", - " pose = b.genjax.uniform_pose(jnp.array([-0.01,-0.01,1.5]), jnp.array([0.01,0.01,3.5])) @ \"pose\"\n", + " pose = (\n", + " b.genjax.uniform_pose(\n", + " jnp.array([-0.01, -0.01, 1.5]), jnp.array([0.01, 0.01, 3.5])\n", + " )\n", + " @ \"pose\"\n", + " )\n", " obj_id = 0\n", - " rendered = b.RENDERER.render(\n", - " pose[None,...] 
, jnp.array([obj_id])\n", - " )[...,:3]\n", + " rendered = b.RENDERER.render(pose[None, ...], jnp.array([obj_id]))[..., :3]\n", " image = b.genjax.image_likelihood(rendered, 0.01, 0.01, 1.0) @ \"image\"\n", " return rendered\n", "\n", + "\n", "importance_jit = jax.jit(single_object_model.importance)\n", "key = jax.random.PRNGKey(5)" ] @@ -95,9 +98,9 @@ "metadata": {}, "outputs": [], "source": [ - "key, (_,gt_trace) = importance_jit(key, genjax.choice_map({}), ())\n", + "key, (_, gt_trace) = importance_jit(key, genjax.choice_map({}), ())\n", "print(gt_trace.get_score())\n", - "b.get_depth_image(gt_trace[\"image\"][...,2])" + "b.get_depth_image(gt_trace[\"image\"][..., 2])" ] }, { @@ -107,7 +110,9 @@ "metadata": {}, "outputs": [], "source": [ - "importance_parallel = jax.jit(jax.vmap(single_object_model.importance, in_axes=(0, None, None)))" + "importance_parallel = jax.jit(\n", + " jax.vmap(single_object_model.importance, in_axes=(0, None, None))\n", + ")" ] }, { @@ -118,7 +123,9 @@ "outputs": [], "source": [ "keys = jax.random.split(key, 1000)\n", - "keys, (weights, traces) = importance_parallel(keys, genjax.choice_map({\"image\": gt_trace[\"image\"]}), ());" + "keys, (weights, traces) = importance_parallel(\n", + " keys, genjax.choice_map({\"image\": gt_trace[\"image\"]}), ()\n", + ")" ] }, { @@ -141,7 +148,9 @@ "sampled_indices = jax.random.categorical(key, weights, shape=(10,))\n", "print(sampled_indices)\n", "print(weights[sampled_indices])\n", - "images = [b.get_depth_image(img[:,:,2]) for img in traces.get_retval()[sampled_indices]]" + "images = [\n", + " b.get_depth_image(img[:, :, 2]) for img in traces.get_retval()[sampled_indices]\n", + "]" ] }, { @@ -151,7 +160,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.multi_panel(images,title=\"10 Posterior Samples\", title_fontsize=20).convert(\"RGB\")" + "b.multi_panel(images, title=\"10 Posterior Samples\", title_fontsize=20).convert(\"RGB\")" ] }, { @@ -190,11 +199,20 @@ "def importance_sampling_with_proposal(key, trace, variance, concentration):\n", " pose_mean = b.transform_from_pos(jnp.array([0.0, 0.0, 1.0]))\n", " pose = b.distributions.gaussian_vmf_jit(key, pose_mean, variance, concentration)\n", - " proposal_weight = b.distributions.gaussian_vmf_logpdf_jit(pose, pose_mean, variance, concentration)\n", - " new_trace = trace.update(key, genjax.choice_map({\"root_pose_0\": pose}), \n", - " b.genjax.make_unknown_change_argdiffs(trace))[1][2]\n", - " return new_trace,new_trace.get_score() - proposal_weight\n", - "importance_sampling_with_proposal_vmap = jax.vmap(importance_sampling_with_proposal, in_axes=(0, None, None, None))" + " proposal_weight = b.distributions.gaussian_vmf_logpdf_jit(\n", + " pose, pose_mean, variance, concentration\n", + " )\n", + " new_trace = trace.update(\n", + " key,\n", + " genjax.choice_map({\"root_pose_0\": pose}),\n", + " b.genjax.make_unknown_change_argdiffs(trace),\n", + " )[1][2]\n", + " return new_trace, new_trace.get_score() - proposal_weight\n", + "\n", + "\n", + "importance_sampling_with_proposal_vmap = jax.vmap(\n", + " importance_sampling_with_proposal, in_axes=(0, None, None, None)\n", + ")" ] }, { @@ -204,7 +222,9 @@ "metadata": {}, "outputs": [], "source": [ - "traces, weights = importance_sampling_with_proposal_vmap(jax.random.split(key, 100), gt_trace, 0.001, 0.001)" + "traces, weights = importance_sampling_with_proposal_vmap(\n", + " jax.random.split(key, 100), gt_trace, 0.001, 0.001\n", + ")" ] }, { @@ -217,8 +237,11 @@ "sampled_indices = jax.random.categorical(key, weights, 
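
`importance_sampling_with_proposal` above computes the classic importance weight in log space, the model's joint score minus the proposal's log-density, after which `jax.random.categorical` resamples indices by the normalized weights. The same recipe on a one-dimensional toy model (Gaussian prior and likelihood, wide Gaussian proposal), as a sketch of the principle rather than the genjax API:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    y_obs = 1.3
    xs = 2.0 * jax.random.normal(key, (1000,))             # proposal q = N(0, 2^2)
    log_q = jax.scipy.stats.norm.logpdf(xs, 0.0, 2.0)
    log_p = jax.scipy.stats.norm.logpdf(xs, 0.0, 1.0) + jax.scipy.stats.norm.logpdf(
        y_obs, xs, 0.5
    )                                                      # prior * likelihood
    log_w = log_p - log_q                                  # score - proposal weight
    idx = jax.random.categorical(jax.random.PRNGKey(1), log_w, shape=(10,))
    posterior_samples = xs[idx]
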
shape=(10,))\n", "print(sampled_indices)\n", "print(weights[sampled_indices])\n", - "images = [b.get_depth_image(img[:,:,2]) for img in b.genjax.get_rendered_image(traces)[sampled_indices]]\n", - "b.multi_panel(images,title=\"10 Posterior Samples\", title_fontsize=20).convert(\"RGB\")" + "images = [\n", + " b.get_depth_image(img[:, :, 2])\n", + " for img in b.genjax.get_rendered_image(traces)[sampled_indices]\n", + "]\n", + "b.multi_panel(images, title=\"10 Posterior Samples\", title_fontsize=20).convert(\"RGB\")" ] }, { diff --git a/scripts/experiments/collaborations/xuan.ipynb b/scripts/experiments/collaborations/xuan.ipynb index d22af66f..6dd422ab 100644 --- a/scripts/experiments/collaborations/xuan.ipynb +++ b/scripts/experiments/collaborations/xuan.ipynb @@ -38,11 +38,7 @@ "outputs": [], "source": [ "original_intrinsics = b.Intrinsics(\n", - " height=500,\n", - " width=500,\n", - " fx=500.0, fy=500.0,\n", - " cx=250.0, cy=250.0,\n", - " near=0.001, far=6.0\n", + " height=500, width=500, fx=500.0, fy=500.0, cx=250.0, cy=250.0, near=0.001, far=6.0\n", ")\n", "\n", "meshes = []\n", @@ -73,11 +69,13 @@ }, "outputs": [], "source": [ - "contact_plane = b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 1.5, 1.0]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - "))\n", + "contact_plane = b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 1.5, 1.0]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + ")\n", "\n", "contact_poses_parallel_jit = jax.jit(\n", " jax.vmap(\n", @@ -96,77 +94,81 @@ "\n", "distinct_colors = b.distinct_colors(3)\n", "ids = jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0])\n", - "color = jnp.array([0, 1, 2, 0, 1, 2,2,0,1])\n", + "color = jnp.array([0, 1, 2, 0, 1, 2, 2, 0, 1])\n", "\n", "# Frame 1\n", - "all_contact_params = jnp.array([\n", - "[\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [0.3, 0.3, 0.0],\n", - " [-0.15, 0.2, 0.0],\n", - " [-0.3, 0.2, 0.0],\n", - " [-0.45, 0.2, 0.0],\n", - " [-0.15, 0.45, 0.0],\n", - " [-0.3, 0.45,0.0],\n", - " [-0.45, 0.45, 0.0],\n", - "],\n", - "# Frame 2\n", - "[\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [0.3, 0.3, 0.0],\n", - " [-0.3, -0.3, 0.0],\n", - " [-0.3, 0.2, 0.0],\n", - " [-0.45, 0.2, 0.0],\n", - " [-0.15, 0.45, 0.0],\n", - " [-0.3, 0.45,0.0],\n", - " [-0.45, 0.45, 0.0],\n", - "],\n", - "[\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [0.3, 0.3, 0.0],\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [-0.45, 0.2, 0.0],\n", - " [-0.15, 0.45, 0.0],\n", - " [-0.3, 0.45,0.0],\n", - " [-0.45, 0.45, 0.0],\n", - "],\n", - "[\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [0.3, 0.3, 0.0],\n", - " [-0.3, -0.3, 0.0],\n", - " [0.3, -0.3, 0.0],\n", - " [0.3, 0.3, 0.0],\n", - " [-0.15, 0.45, 0.0],\n", - " [-0.3, 0.45,0.0],\n", - " [-0.45, 0.45, 0.0],\n", - "]\n", - "\n", - "])\n", + "all_contact_params = jnp.array(\n", + " [\n", + " [\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [0.3, 0.3, 0.0],\n", + " [-0.15, 0.2, 0.0],\n", + " [-0.3, 0.2, 0.0],\n", + " [-0.45, 0.2, 0.0],\n", + " [-0.15, 0.45, 0.0],\n", + " [-0.3, 0.45, 0.0],\n", + " [-0.45, 0.45, 0.0],\n", + " ],\n", + " # Frame 2\n", + " [\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [0.3, 0.3, 0.0],\n", + " [-0.3, -0.3, 0.0],\n", + " [-0.3, 0.2, 0.0],\n", + " [-0.45, 0.2, 0.0],\n", + " [-0.15, 0.45, 0.0],\n", + " [-0.3, 0.45, 0.0],\n", + " [-0.45, 0.45, 
0.0],\n", + " ],\n", + " [\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [0.3, 0.3, 0.0],\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [-0.45, 0.2, 0.0],\n", + " [-0.15, 0.45, 0.0],\n", + " [-0.3, 0.45, 0.0],\n", + " [-0.45, 0.45, 0.0],\n", + " ],\n", + " [\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [0.3, 0.3, 0.0],\n", + " [-0.3, -0.3, 0.0],\n", + " [0.3, -0.3, 0.0],\n", + " [0.3, 0.3, 0.0],\n", + " [-0.15, 0.45, 0.0],\n", + " [-0.3, 0.45, 0.0],\n", + " [-0.45, 0.45, 0.0],\n", + " ],\n", + " ]\n", + ")\n", "\n", "rgbd_images = []\n", "all_poses = []\n", "for i in range(len(all_contact_params)):\n", " contact_params = all_contact_params[i]\n", " poses = contact_plane @ contact_poses_parallel_jit(\n", - " contact_params,\n", - " 3,\n", - " b.RENDERER.model_box_dims[ids]\n", + " contact_params, 3, b.RENDERER.model_box_dims[ids]\n", " )\n", " all_poses.append(poses)\n", " viz.clear()\n", "\n", - " viz.make_trimesh(table_mesh, contact_plane, np.array([221, 174, 126, 255.0])/255.0)\n", + " viz.make_trimesh(\n", + " table_mesh, contact_plane, np.array([221, 174, 126, 255.0]) / 255.0\n", + " )\n", " for i in range(len(poses)):\n", - " viz.make_trimesh(b.RENDERER.meshes[ids[i]], poses[i], np.array([*distinct_colors[color[i]], 1.0]))\n", + " viz.make_trimesh(\n", + " b.RENDERER.meshes[ids[i]],\n", + " poses[i],\n", + " np.array([*distinct_colors[color[i]], 1.0]),\n", + " )\n", "\n", " rgbd = viz.capture_image(original_intrinsics, jnp.eye(4))\n", - " rgbd_images.append(rgbd)\n", - "\n" + " rgbd_images.append(rgbd)" ] }, { @@ -177,8 +179,8 @@ }, "outputs": [], "source": [ - "np.savez(\"rgbd.npz\",rgbd_images[0])\n", - "b.hstack_images([b.get_rgb_image(rgbd.rgb) for rgbd in rgbd_images])\n" + "np.savez(\"rgbd.npz\", rgbd_images[0])\n", + "b.hstack_images([b.get_rgb_image(rgbd.rgb) for rgbd in rgbd_images])" ] }, { @@ -189,7 +191,7 @@ }, "outputs": [], "source": [ - "rgbd_original = np.load(\"rgbd.npz\",allow_pickle=True)[\"arr_0\"].item()\n", + "rgbd_original = np.load(\"rgbd.npz\", allow_pickle=True)[\"arr_0\"].item()\n", "SCALING_FACTOR = 0.3\n", "rgbd = b.scale_rgbd(rgbd_original, SCALING_FACTOR)" ] @@ -227,11 +229,11 @@ "b.setup_renderer(intrinsics)\n", "for m in meshes:\n", " b.RENDERER.add_mesh(m)\n", - " \n", - "observed_point_cloud_image = b.RENDERER.render_multiobject(all_poses[0], ids)[:,:,:3]\n", + "\n", + "observed_point_cloud_image = b.RENDERER.render_multiobject(all_poses[0], ids)[:, :, :3]\n", "b.clear()\n", - "b.show_cloud(\"1\", observed_point_cloud_image[:,:,:3].reshape(-1,3))\n", - "b.get_depth_image(observed_point_cloud_image[:,:,2])\n" + "b.show_cloud(\"1\", observed_point_cloud_image[:, :, :3].reshape(-1, 3))\n", + "b.get_depth_image(observed_point_cloud_image[:, :, 2])" ] }, { @@ -243,18 +245,15 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.5, jnp.pi, (11,11,11)), (0.2, jnp.pi/3, (11,11,11)), (0.1, jnp.pi/5, (11,11,1)),\n", - " (0.05, jnp.pi/5, (11,11,11)), \n", + " (0.5, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi / 3, (11, 11, 11)),\n", + " (0.1, jnp.pi / 5, (11, 11, 1)),\n", + " (0.05, jnp.pi / 5, (11, 11, 11)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n", - "\n" + " b.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -265,12 +264,19 @@ }, "outputs": [], "source": [ - 
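
The `grid_params` schedule above is a coarse-to-fine search: each stage evaluates an 11x11x11 grid whose half-widths shrink (0.5 to 0.2 to 0.1 to 0.05 in translation), recentered on the previous stage's best hypothesis, so fine resolution is reached without one enormous grid. The principle in one dimension:

    import numpy as np

    def coarse_to_fine_argmax(score, center, half_widths, n=11):
        # Each stage evaluates n points spanning +/- w around the current
        # center, then recenters on the best point before the finer stage.
        for w in half_widths:
            grid = center + np.linspace(-w, w, n)
            center = grid[np.argmax(score(grid))]
        return center

    best = coarse_to_fine_argmax(lambda x: -((x - 0.73) ** 2), 0.0, [1.0, 0.2, 0.05])
    # best equals 0.73 to within the finest grid spacing (0.01)
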
"threedp3_likelihood_full_hierarchical_bayes_per_pixel_jit = jax.jit(jax.vmap(jax.vmap(jax.vmap(\n", - " b.threedp3_likelihood_per_pixel_jit,\n", - " in_axes=(None, None, None, 0, None, None)),\n", - " in_axes=(None, None, 0, None, None, None)),\n", - " in_axes=(None, 0, None, None, None, None)\n", - "), static_argnames=('filter_size',))" + "threedp3_likelihood_full_hierarchical_bayes_per_pixel_jit = jax.jit(\n", + " jax.vmap(\n", + " jax.vmap(\n", + " jax.vmap(\n", + " b.threedp3_likelihood_per_pixel_jit,\n", + " in_axes=(None, None, None, 0, None, None),\n", + " ),\n", + " in_axes=(None, None, 0, None, None, None),\n", + " ),\n", + " in_axes=(None, 0, None, None, None, None),\n", + " ),\n", + " static_argnames=(\"filter_size\",),\n", + ")" ] }, { @@ -281,7 +287,7 @@ }, "outputs": [], "source": [ - "VARIANCE_GRID = jnp.array([ 0.0001])\n", + "VARIANCE_GRID = jnp.array([0.0001])\n", "OUTLIER_GRID = jnp.array([0.01])\n", "OUTLIER_VOLUME = 1000.0" ] @@ -303,21 +309,35 @@ " )\n", " potential_poses = jnp.concatenate(\n", " [\n", - " jnp.tile(trace.poses[:,None,...], (1,potential_new_object_poses.shape[0],1,1)),\n", - " potential_new_object_poses[None,...]\n", + " jnp.tile(\n", + " trace.poses[:, None, ...],\n", + " (1, potential_new_object_poses.shape[0], 1, 1),\n", + " ),\n", + " potential_new_object_poses[None, ...],\n", " ]\n", " )\n", " traces = b.Traces(\n", - " potential_poses, jnp.concatenate([trace.ids, jnp.array([obj_id])]), VARIANCE_GRID, OUTLIER_GRID,\n", - " trace.outlier_volume, trace.observation\n", + " potential_poses,\n", + " jnp.concatenate([trace.ids, jnp.array([obj_id])]),\n", + " VARIANCE_GRID,\n", + " OUTLIER_GRID,\n", + " trace.outlier_volume,\n", + " trace.observation,\n", " )\n", " p = b.score_traces(traces)\n", "\n", - " ii,jj,kk = jnp.unravel_index(p.argmax(), p.shape)\n", + " ii, jj, kk = jnp.unravel_index(p.argmax(), p.shape)\n", " contact_param = contact_param_grid[ii]\n", - " return contact_param, traces[ii,jj,kk]\n", + " return contact_param, traces[ii, jj, kk]\n", + "\n", "\n", - "refine_jit = jax.jit(refine, static_argnames=(\"i\", \"obj_id\",))" + "refine_jit = jax.jit(\n", + " refine,\n", + " static_argnames=(\n", + " \"i\",\n", + " \"obj_id\",\n", + " ),\n", + ")" ] }, { @@ -336,16 +356,23 @@ "\n", "\n", "gt_trace = b.Trace(\n", - " poses, ids, VARIANCE_GRID[0], OUTLIER_GRID[0], OUTLIER_VOLUME,\n", - " observed_point_cloud_image\n", + " poses,\n", + " ids,\n", + " VARIANCE_GRID[0],\n", + " OUTLIER_GRID[0],\n", + " OUTLIER_VOLUME,\n", + " observed_point_cloud_image,\n", ")\n", "print(b.score_trace(gt_trace))\n", - "b.show_cloud(\"rerender\", b.render_image(gt_trace)[:,:,:3].reshape(-1,3),color=b.RED)\n", + "b.show_cloud(\"rerender\", b.render_image(gt_trace)[:, :, :3].reshape(-1, 3), color=b.RED)\n", "\n", "trace = b.Trace(\n", - " jnp.zeros((0,4,4)), jnp.array([],dtype=jnp.int32),\n", - " VARIANCE_GRID[0], OUTLIER_GRID[0], OUTLIER_VOLUME,\n", - " observed_point_cloud_image\n", + " jnp.zeros((0, 4, 4)),\n", + " jnp.array([], dtype=jnp.int32),\n", + " VARIANCE_GRID[0],\n", + " OUTLIER_GRID[0],\n", + " OUTLIER_VOLUME,\n", + " observed_point_cloud_image,\n", ")\n", "b.viz_trace_meshcat(trace)" ] @@ -369,10 +396,7 @@ " contact_param, trace_ = refine_jit(trace, contact_param, c2f_iter, obj_id)\n", " trace_path.append(trace_)\n", "\n", - " all_paths.append(\n", - " trace_path\n", - " )\n", - "\n", + " all_paths.append(trace_path)\n", "\n", " scores = jnp.array([b.score_trace(t[-1]) for t in all_paths])\n", " normalized_scores = 
b.utils.normalize_log_scores(scores)\n", @@ -381,7 +405,7 @@ " # print(order)\n", " new_trace = all_paths[jnp.argmax(scores)][-1]\n", " trace = new_trace\n", - " b.viz_trace_meshcat(trace)\n" + " b.viz_trace_meshcat(trace)" ] }, { @@ -435,7 +459,10 @@ }, "outputs": [], "source": [ - "[(b.score_trace(t, renderer),t.variance, t.outlier_prob, t.outlier_volume) for t in all_paths[0]]" + "[\n", + " (b.score_trace(t, renderer), t.variance, t.outlier_prob, t.outlier_volume)\n", + " for t in all_paths[0]\n", + "]" ] }, { @@ -458,7 +485,7 @@ "outputs": [], "source": [ "reconstruction = b.render_image(trace, renderer)\n", - "b.get_depth_image(reconstruction[:,:,2])" + "b.get_depth_image(reconstruction[:, :, 2])" ] }, { @@ -471,9 +498,7 @@ "source": [ "print(trace.variance, trace.outlier_prob, trace.outlier_volume)\n", "p = b.threedp3_likelihood_per_pixel_jit(\n", - " trace.observation, reconstruction[:,:,:3],\n", - " trace.variance, 0.0, 1.0,\n", - " 3\n", + " trace.observation, reconstruction[:, :, :3], trace.variance, 0.0, 1.0, 3\n", ")\n", "outlier_density = jnp.log(trace.outlier_prob) - jnp.log(0.0005)\n", "b.get_depth_image(1.0 * (outlier_density > p), min=0.0, max=1.0)" @@ -524,7 +549,7 @@ "outputs": [], "source": [ "b.clear()\n", - "seg = b.render_image(trace, renderer)[:,:,3]\n", + "seg = b.render_image(trace, renderer)[:, :, 3]\n", "# b.show_cloud(\"rerender\", b.render_image(trace,renderer)[:,:,:3].reshape(-1,3),color=b.RED)" ] }, @@ -560,10 +585,12 @@ "source": [ "inferred_colors = []\n", "distinct_colors = jnp.array(distinct_colors)\n", - "for i in range(1,len(trace.ids)+1):\n", - " seg_colors = rgbd.rgb[seg == i ,:3]\n", - " distances = jnp.abs(seg_colors[:,None,:]/255.0 - distinct_colors[None,...]).sum(-1)\n", - " values, counts = np.unique(jnp.argmin(distances,axis=-1), return_counts=True)\n", + "for i in range(1, len(trace.ids) + 1):\n", + " seg_colors = rgbd.rgb[seg == i, :3]\n", + " distances = jnp.abs(\n", + " seg_colors[:, None, :] / 255.0 - distinct_colors[None, ...]\n", + " ).sum(-1)\n", + " values, counts = np.unique(jnp.argmin(distances, axis=-1), return_counts=True)\n", " inferred_colors.append(values[counts.argmax()])\n", "inferred_colors" ] @@ -662,9 +689,9 @@ }, "outputs": [], "source": [ - "x = b.render_image(gt_trace, renderer)[:,:,:3]\n", + "x = b.render_image(gt_trace, renderer)[:, :, :3]\n", "b.clear()\n", - "b.show_cloud(\"1\", x.reshape(-1,3))" + "b.show_cloud(\"1\", x.reshape(-1, 3))" ] }, { diff --git a/scripts/experiments/deeplearning/dino/dif_interactive.ipynb b/scripts/experiments/deeplearning/dino/dif_interactive.ipynb index 1e3e0cc9..f6b1bab4 100644 --- a/scripts/experiments/deeplearning/dino/dif_interactive.ipynb +++ b/scripts/experiments/deeplearning/dino/dif_interactive.ipynb @@ -43,7 +43,7 @@ "import bayes3d.utils.ycb_loader\n", "from bayes3d.viz.open3dviz import Open3DVisualizer\n", "from tqdm import tqdm\n", - "import open3d as o3d\n" + "import open3d as o3d" ] }, { @@ -70,6 +70,7 @@ "from diffusers import DDIMScheduler\n", "from diffusers import StableDiffusionPipeline\n", "\n", + "\n", "class MyUNet2DConditionModel(UNet2DConditionModel):\n", " def forward(\n", " self,\n", @@ -80,7 +81,8 @@ " class_labels: Optional[torch.Tensor] = None,\n", " timestep_cond: Optional[torch.Tensor] = None,\n", " attention_mask: Optional[torch.Tensor] = None,\n", - " cross_attention_kwargs: Optional[Dict[str, Any]] = None):\n", + " cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n", + " ):\n", " r\"\"\"\n", " Args:\n", " sample (`torch.FloatTensor`): (batch, 
channel, height, width) noisy inputs tensor\n", @@ -142,7 +144,9 @@ "\n", " if self.class_embedding is not None:\n", " if class_labels is None:\n", - " raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n", + " raise ValueError(\n", + " \"class_labels should be provided when num_class_embeds > 0\"\n", + " )\n", "\n", " if self.config.class_embed_type == \"timestep\":\n", " class_labels = self.time_proj(class_labels)\n", @@ -156,7 +160,10 @@ " # 3. down\n", " down_block_res_samples = (sample,)\n", " for downsample_block in self.down_blocks:\n", - " if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n", + " if (\n", + " hasattr(downsample_block, \"has_cross_attention\")\n", + " and downsample_block.has_cross_attention\n", + " ):\n", " sample, res_samples = downsample_block(\n", " hidden_states=sample,\n", " temb=emb,\n", @@ -182,21 +189,25 @@ " # 5. up\n", " up_ft = {}\n", " for i, upsample_block in enumerate(self.up_blocks):\n", - "\n", " if i > np.max(up_ft_indices):\n", " break\n", "\n", " is_final_block = i == len(self.up_blocks) - 1\n", "\n", " res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n", - " down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n", + " down_block_res_samples = down_block_res_samples[\n", + " : -len(upsample_block.resnets)\n", + " ]\n", "\n", " # if we have not reached the final block and need to forward the\n", " # upsample size, we do it here\n", " if not is_final_block and forward_upsample_size:\n", " upsample_size = down_block_res_samples[-1].shape[2:]\n", "\n", - " if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n", + " if (\n", + " hasattr(upsample_block, \"has_cross_attention\")\n", + " and upsample_block.has_cross_attention\n", + " ):\n", " sample = upsample_block(\n", " hidden_states=sample,\n", " temb=emb,\n", @@ -208,16 +219,20 @@ " )\n", " else:\n", " sample = upsample_block(\n", - " hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n", + " hidden_states=sample,\n", + " temb=emb,\n", + " res_hidden_states_tuple=res_samples,\n", + " upsample_size=upsample_size,\n", " )\n", "\n", " if i in up_ft_indices:\n", " up_ft[i] = sample.detach()\n", "\n", " output = {}\n", - " output['up_ft'] = up_ft\n", + " output[\"up_ft\"] = up_ft\n", " return output\n", "\n", + "\n", "class OneStepSDPipeline(StableDiffusionPipeline):\n", " @torch.no_grad()\n", " def __call__(\n", @@ -230,28 +245,36 @@ " prompt_embeds: Optional[torch.FloatTensor] = None,\n", " callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n", " callback_steps: int = 1,\n", - " cross_attention_kwargs: Optional[Dict[str, Any]] = None\n", + " cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n", " ):\n", - "\n", " device = self._execution_device\n", - " latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor\n", + " latents = (\n", + " self.vae.encode(img_tensor).latent_dist.sample()\n", + " * self.vae.config.scaling_factor\n", + " )\n", " t = torch.tensor(t, dtype=torch.long, device=device)\n", " noise = torch.randn_like(latents).to(device)\n", " latents_noisy = self.scheduler.add_noise(latents, noise, t)\n", - " unet_output = self.unet(latents_noisy,\n", - " t,\n", - " up_ft_indices,\n", - " encoder_hidden_states=prompt_embeds,\n", - " cross_attention_kwargs=cross_attention_kwargs)\n", + " unet_output = self.unet(\n", + " 
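
`OneStepSDPipeline` below follows the DIFT recipe: VAE-encode the image, add forward-diffusion noise at a single timestep `t`, run the UNet once, and keep intermediate up-block activations instead of a denoised image. The `scheduler.add_noise` call computes the forward marginal q(x_t | x_0); a sketch of that step with a made-up alpha schedule (not Stable Diffusion's real one):

    import torch

    def add_noise(x0, noise, t, alphas_cumprod):
        # q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        a = alphas_cumprod[t]
        return a.sqrt() * x0 + (1 - a).sqrt() * noise

    alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)  # toy schedule
    x0 = torch.randn(1, 4, 64, 64)                       # latent, not pixels
    x_t = add_noise(x0, torch.randn_like(x0), torch.tensor(261), alphas_cumprod)
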
latents_noisy,\n", + " t,\n", + " up_ft_indices,\n", + " encoder_hidden_states=prompt_embeds,\n", + " cross_attention_kwargs=cross_attention_kwargs,\n", + " )\n", " return unet_output\n", "\n", "\n", "class SDFeaturizer:\n", - " def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):\n", + " def __init__(self, sd_id=\"stabilityai/stable-diffusion-2-1\"):\n", " unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder=\"unet\")\n", - " onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)\n", + " onestep_pipe = OneStepSDPipeline.from_pretrained(\n", + " sd_id, unet=unet, safety_checker=None\n", + " )\n", " onestep_pipe.vae.decoder = None\n", - " onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder=\"scheduler\")\n", + " onestep_pipe.scheduler = DDIMScheduler.from_pretrained(\n", + " sd_id, subfolder=\"scheduler\"\n", + " )\n", " gc.collect()\n", " onestep_pipe = onestep_pipe.to(\"cuda\")\n", " onestep_pipe.enable_attention_slicing()\n", @@ -259,27 +282,31 @@ " self.pipe = onestep_pipe\n", "\n", " @torch.no_grad()\n", - " def forward(self,\n", - " img_tensor, # single image, [1,c,h,w]\n", - " prompt,\n", - " t=261,\n", - " up_ft_index=1,\n", - " ensemble_size=8):\n", + " def forward(\n", + " self,\n", + " img_tensor, # single image, [1,c,h,w]\n", + " prompt,\n", + " t=261,\n", + " up_ft_index=1,\n", + " ensemble_size=8,\n", + " ):\n", " print(img_tensor.shape)\n", - " img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w\n", + " img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w\n", " prompt_embeds = self.pipe.encode_prompt(\n", " prompt=prompt,\n", - " device='cuda',\n", + " device=\"cuda\",\n", " num_images_per_prompt=1,\n", - " do_classifier_free_guidance=False)[0] # [1, 77, dim]\n", + " do_classifier_free_guidance=False,\n", + " )[0] # [1, 77, dim]\n", " prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)\n", " unet_ft_all = self.pipe(\n", " img_tensor=img_tensor,\n", " t=t,\n", " up_ft_indices=[up_ft_index],\n", - " prompt_embeds=prompt_embeds)\n", - " unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w\n", - " unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w\n", + " prompt_embeds=prompt_embeds,\n", + " )\n", + " unet_ft = unet_ft_all[\"up_ft\"][up_ft_index] # ensem, c, h, w\n", + " unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w\n", " return unet_ft" ] }, @@ -324,15 +351,11 @@ } ], "source": [ - "w,h = 400,400\n", + "w, h = 400, 400\n", "intrinsics = b.Intrinsics(\n", - " height=h,\n", - " width=w,\n", - " fx=2000.0, fy=2000.0,\n", - " cx=w/2.0, cy=h/2.0,\n", - " near=0.001, far=6.0\n", + " height=h, width=w, fx=2000.0, fy=2000.0, cx=w / 2.0, cy=h / 2.0, near=0.001, far=6.0\n", ")\n", - "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0/14.0)\n", + "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0 / 14.0)\n", "scaled_down_intrinsics" ] }, @@ -372,11 +395,13 @@ }, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "mesh_paths = []\n", - "for idx in range(1,22):\n", - " mesh_paths.append(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"))\n", - "SCALING_FACTOR = 1.0/1000.0\n" + "for idx in range(1, 22):\n", + " mesh_paths.append(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", + " )\n", + "SCALING_FACTOR 
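
Typical use of the `SDFeaturizer` defined above, mirroring `get_embeddings` later in this notebook (the image path here is hypothetical): scale the image to [-1, 1], call `forward`, and get back one feature map averaged over the noise ensemble:

    import torch
    from PIL import Image
    from torchvision.transforms import PILToTensor

    dift = SDFeaturizer(sd_id="stabilityai/stable-diffusion-2-1")
    img = Image.open("example.png").convert("RGB")       # hypothetical input
    img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2  # (3, h, w) in [-1, 1]
    feat = dift.forward(img_tensor, prompt="find the object", ensemble_size=2)
    # feat: (1, c, h', w') UNet up-block features, mean over the ensemble
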
= 1.0 / 1000.0" ] }, { @@ -397,7 +422,7 @@ "metadata": {}, "outputs": [], "source": [ - "viz = Open3DVisualizer(intrinsics)\n" + "viz = Open3DVisualizer(intrinsics)" ] }, { @@ -431,11 +456,19 @@ "metadata": {}, "outputs": [], "source": [ - "object_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 2.6, 0.0]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]])\n", + "object_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 2.6, 0.0]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]\n", + " ]\n", + ")\n", "# for (i, pose) in enumerate(object_poses):\n", "# b.show_pose(f\"{i}\", pose)" ] @@ -459,7 +492,7 @@ "for i, pose in tqdm(enumerate(object_poses)):\n", " # if i > 0:\n", " # mesh.meshes[0].mesh.transform(b.inv\n", - " \n", + "\n", " # viz.render.scene.add_model(f\"1\", mesh)\n", " rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(pose))\n", " images.append(rgbd)\n", @@ -514,7 +547,7 @@ } ], "source": [ - "dift = SDFeaturizer(sd_id='stabilityai/stable-diffusion-2-1')\n", + "dift = SDFeaturizer(sd_id=\"stabilityai/stable-diffusion-2-1\")\n", "\n", "# heatmap = get_heatmap(img1_embedding[45,45], img2_embedding)\n", "# scaled_up_heatmap = b.utils.resize(heatmap, img2.shape[0], img2.shape[1])\n", @@ -531,20 +564,21 @@ "def get_embeddings(rgb):\n", " img = b.get_rgb_image(rgb)\n", " img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2\n", - " img_feat_norm = dift.forward(img_tensor,\n", - " prompt=f\"find the object\",\n", - " ensemble_size=2)\n", - " img_feat_norm = nn.Upsample(size=(rgb.shape[0],rgb.shape[1]), mode='bilinear')(img_feat_norm)\n", + " img_feat_norm = dift.forward(img_tensor, prompt=f\"find the object\", ensemble_size=2)\n", + " img_feat_norm = nn.Upsample(size=(rgb.shape[0], rgb.shape[1]), mode=\"bilinear\")(\n", + " img_feat_norm\n", + " )\n", " print(img_feat_norm.shape)\n", " output = jnp.array(img_feat_norm.cpu().detach().numpy())[0]\n", " del img_feat_norm\n", " del img_tensor\n", - " return jnp.transpose(output, (1,2,0))\n", + " return jnp.transpose(output, (1, 2, 0))\n", + "\n", "\n", "def get_heatmap(target_embedding, all_embeddings):\n", " dot_products = jnp.einsum(\"i, abi->ab\", target_embedding, all_embeddings)\n", " heatmap = dot_products / 2.0 + 0.5\n", - " return heatmap\n" + " return heatmap" ] }, { @@ -577,13 +611,13 @@ ], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('55', '22', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"55\", \"22\", bop_ycb_dir)\n", "\n", "img2 = rgbd.rgb\n", "\n", "img2_embedding = get_embeddings(img2)\n", "print(img2_embedding.shape)\n", - "b.get_rgb_image(rgbd.rgb)\n" + "b.get_rgb_image(rgbd.rgb)" ] }, { @@ -679,8 +713,8 @@ } ], "source": [ - "heatmap = get_heatmap(img1_embedding[70,70], img1_embedding)\n", - "plt.matshow(heatmap)\n" + "heatmap = get_heatmap(img1_embedding[70, 70], img1_embedding)\n", + "plt.matshow(heatmap)" ] }, { @@ -794,32 +828,33 @@ "img1 = images[50].rgb\n", "img1_embedding = get_embeddings(img1)\n", "\n", - 
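
`get_heatmap` above assumes both embedding sets are L2-normalized (done just before the interactive plot below), so the einsum is a per-pixel cosine similarity mapped from [-1, 1] into [0, 1]. Folding the normalization in makes that explicit:

    import jax.numpy as jnp

    def cosine_heatmap(target, embeddings):
        # target: (d,), embeddings: (h, w, d) -> similarity map (h, w) in [0, 1]
        target = target / jnp.linalg.norm(target)
        embeddings = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
        return jnp.einsum("i,abi->ab", target, embeddings) / 2.0 + 0.5
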
"img1_embedding = img1_embedding / jnp.linalg.norm(img1_embedding, axis=-1)[...,None]\n", - "img2_embedding = img2_embedding / jnp.linalg.norm(img2_embedding, axis=-1)[...,None]\n", + "img1_embedding = img1_embedding / jnp.linalg.norm(img1_embedding, axis=-1)[..., None]\n", + "img2_embedding = img2_embedding / jnp.linalg.norm(img2_embedding, axis=-1)[..., None]\n", "\n", - "axes[0].imshow(img1 /255.0)\n", + "axes[0].imshow(img1 / 255.0)\n", "axes[1].imshow(img2 / 255.0)\n", + "\n", + "\n", "def onclick(event):\n", " print(\"hello\")\n", " # cos = nn.CosineSimilarity(dim=1)\n", "\n", " x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n", - " x2, y2 = int(np.round(event.xdata)), int(np.round(event.ydata ))\n", - " print(x,y)\n", - " data.append((x,y))\n", + " x2, y2 = int(np.round(event.xdata)), int(np.round(event.ydata))\n", + " print(x, y)\n", + " data.append((x, y))\n", "\n", - " target_embedding = img1_embedding[y2, x2,:]\n", + " target_embedding = img1_embedding[y2, x2, :]\n", " heatmap = get_heatmap(target_embedding, img1_embedding)\n", "\n", " # dot_products = jnp.einsum(\"i, abi->ab\", target_embedding, img2_embedding)\n", "\n", - "\n", " axes[0].clear()\n", " axes[1].clear()\n", - " axes[0].imshow(img1 /255.0)\n", + " axes[0].imshow(img1 / 255.0)\n", " axes[1].imshow(img2 / 255.0)\n", - " axes[0].imshow(255 * heatmap, alpha=0.45, cmap='viridis')\n", - " axes[0].scatter(x, y, c='r', s=10.0)\n", + " axes[0].imshow(255 * heatmap, alpha=0.45, cmap=\"viridis\")\n", + " axes[0].scatter(x, y, c=\"r\", s=10.0)\n", "\n", " heatmap = get_heatmap(target_embedding, img2_embedding)\n", " scaled_up_heatmap = b.utils.resize(heatmap, img2.shape[0], img2.shape[1])\n", @@ -827,13 +862,14 @@ " max_yx = np.unravel_index(scaled_up_heatmap.argmax(), scaled_up_heatmap.shape)\n", " # axes[1].axis('off')\n", " print(\"max \", scaled_up_heatmap.max())\n", - " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap='viridis')\n", + " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap=\"viridis\")\n", " # axes[1].axis('off')\n", - " axes[1].scatter(max_yx[1], max_yx[0], c='r', s=10)\n", + " axes[1].scatter(max_yx[1], max_yx[0], c=\"r\", s=10)\n", " # axes[1].set_title('target image')\n", " # gc.collect()\n", "\n", - "fig.canvas.mpl_connect('button_press_event', onclick)\n", + "\n", + "fig.canvas.mpl_connect(\"button_press_event\", onclick)\n", "plt.show()" ] }, @@ -911,12 +947,14 @@ "x = np.linspace(0, 2 * np.pi)\n", "fig = plt.figure()\n", "ax = fig.add_subplot(1, 1, 1)\n", - "line, = ax.plot(x, np.sin(x))\n", + "(line,) = ax.plot(x, np.sin(x))\n", "\n", - "def update(w = 1.0):\n", + "\n", + "def update(w=1.0):\n", " line.set_ydata(np.sin(w * x))\n", " fig.canvas.draw_idle()\n", "\n", + "\n", "interact(update);" ] }, @@ -950,7 +988,7 @@ "source": [ "num_images = len(images)\n", "num_training_images = 10\n", - "training_indices = jnp.arange(0,num_images-1, num_images // num_training_images)\n", + "training_indices = jnp.arange(0, num_images - 1, num_images // num_training_images)\n", "# b.hstack_images([\n", "# b.get_rgb_image(images[idx].rgb) for idx in training_indices\n", "# ])" @@ -970,44 +1008,69 @@ "sparse_descriptors = []\n", "for iteration in range(len(training_indices)):\n", " index = training_indices[iteration]\n", - " index_next = training_indices[(iteration+1) % len(training_indices)]\n", + " index_next = training_indices[(iteration + 1) % len(training_indices)]\n", " print(index, index_next)\n", " keys = jax.random.split(key)[1]\n", - " \n", + "\n", " training_image = 
images[index]\n", " object_pose = object_poses[index]\n", - " \n", - " scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", + "\n", + " scaled_down_training_image = training_image.scale_rgbd(1.0 / 14.0)\n", " embedding_image = get_embeddings(training_image)\n", " embedding_image_next = get_embeddings(images[index_next])\n", - " \n", - " foreground_mask = (jnp.inf != scaled_down_training_image.depth)\n", + "\n", + " foreground_mask = jnp.inf != scaled_down_training_image.depth\n", " foreground_pixel_coordinates = jnp.transpose(jnp.vstack(jnp.where(foreground_mask)))\n", - " \n", + "\n", " depth = jnp.array(scaled_down_training_image.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image = b.t3d.unproject_depth(depth, scaled_down_training_image.intrinsics)\n", - " point_cloud_image_object_frame = b.t3d.apply_transform(point_cloud_image, b.t3d.inverse_pose(object_pose))\n", - " \n", - " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0/14.0)\n", + " point_cloud_image = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image.intrinsics\n", + " )\n", + " point_cloud_image_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image, b.t3d.inverse_pose(object_pose)\n", + " )\n", + "\n", + " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0 / 14.0)\n", " depth = jnp.array(scaled_down_training_image_next.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image_next = b.t3d.unproject_depth(depth, scaled_down_training_image_next.intrinsics)\n", - " point_cloud_image_next_object_frame = b.t3d.apply_transform(point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next]))\n", - " \n", - " embeddings_subset = embedding_image[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " coordinates_subset = point_cloud_image_object_frame[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " similarity_embedding = jnp.einsum(\"abi, ki->abk\", embedding_image_next, embeddings_subset)\n", + " point_cloud_image_next = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image_next.intrinsics\n", + " )\n", + " point_cloud_image_next_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next])\n", + " )\n", + "\n", + " embeddings_subset = embedding_image[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " coordinates_subset = point_cloud_image_object_frame[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " similarity_embedding = jnp.einsum(\n", + " \"abi, ki->abk\", embedding_image_next, embeddings_subset\n", + " )\n", " best_match = similarity_embedding.argmax(-1)\n", - " distance_to_best_match = jnp.linalg.norm(point_cloud_image_next_object_frame - coordinates_subset[best_match,:], axis=-1)\n", - " \n", + " distance_to_best_match = jnp.linalg.norm(\n", + " point_cloud_image_next_object_frame - coordinates_subset[best_match, :], axis=-1\n", + " )\n", + "\n", " selected = (distance_to_best_match < 0.01) * (similarity_embedding.max(-1) > 0.9)\n", " subset = jnp.unique(best_match[selected])\n", "\n", - "\n", - " _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " 
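
The loop above distills keypoints by cross-view verification: a view-A foreground pixel survives only if some pixel in the next view matches it with cosine similarity above 0.9 and the two pixels' object-frame 3D points agree to within 1 cm. The selection logic, factored out as a function (argument shapes as in the loop):

    import jax.numpy as jnp

    def select_stable_keypoints(emb_a, pts_a, emb_b, pts_b,
                                sim_thresh=0.9, dist_thresh=0.01):
        # emb_a: (n, d) view-A foreground embeddings, pts_a: (n, 3) object frame
        # emb_b: (h, w, d) view-B embedding image,    pts_b: (h, w, 3) object frame
        sim = jnp.einsum("abi,ki->abk", emb_b, emb_a)         # (h, w, n)
        best = sim.argmax(-1)                                 # best A index per B pixel
        dist = jnp.linalg.norm(pts_b - pts_a[best], axis=-1)  # 3D agreement check
        keep = (dist < dist_thresh) & (sim.max(-1) > sim_thresh)
        return jnp.unique(best[keep])                         # indices into view A
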
_keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", + " _keypoint_embeddings = embedding_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " keypoint_world_coordinates = point_cloud_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " _keypoint_coordinates = b.t3d.apply_transform(\n", + " keypoint_world_coordinates, b.t3d.inverse_pose(object_pose)\n", + " )\n", "\n", " keypoint_coordinates.append(_keypoint_coordinates)\n", " keypoint_embeddings.append(_keypoint_embeddings)\n", @@ -1040,10 +1103,10 @@ "# index_next = training_indices[(iteration+1) % len(training_indices)]\n", "# print(index, index_next)\n", "# keys = jax.random.split(key)[1]\n", - " \n", + "\n", "# training_image = images[index]\n", "# object_pose = object_poses[index]\n", - " \n", + "\n", "# scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", "# embedding_image = get_embeddings(training_image)\n", "\n", @@ -1061,7 +1124,7 @@ "# keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", "# _keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", "# _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " \n", + "\n", "# keypoint_coordinates.append(_keypoint_coordinates)\n", "# keypoint_embeddings.append(_keypoint_embeddings)\n", "# del embedding_image\n", @@ -1078,24 +1141,29 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings):\n", - " point_cloud_img = b.RENDERER.render(pose[None,...], jnp.array([0]))[:,:,:3]\n", - " point_cloud_img_in_object_frame = b.t3d.apply_transform(point_cloud_img, b.t3d.inverse_pose(pose))\n", - "\n", - " distances_to_keypoints = (\n", - " jnp.linalg.norm(point_cloud_img_in_object_frame[:, :,None,...] 
- keypoint_coordinates[None, None,:,...],\n", - " axis=-1\n", - " ))\n", + " point_cloud_img = b.RENDERER.render(pose[None, ...], jnp.array([0]))[:, :, :3]\n", + " point_cloud_img_in_object_frame = b.t3d.apply_transform(\n", + " point_cloud_img, b.t3d.inverse_pose(pose)\n", + " )\n", + "\n", + " distances_to_keypoints = jnp.linalg.norm(\n", + " point_cloud_img_in_object_frame[:, :, None, ...]\n", + " - keypoint_coordinates[None, None, :, ...],\n", + " axis=-1,\n", + " )\n", " index_of_nearest_keypoint = distances_to_keypoints.argmin(2)\n", " distance_to_nearest_keypoints = distances_to_keypoints.min(2)\n", "\n", " DISTANCE_THRESHOLD = 0.04\n", - " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[...,None]\n", + " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[..., None]\n", " selected_keypoints = keypoint_coordinates[index_of_nearest_keypoint]\n", - " rendered_embeddings_image = keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " rendered_embeddings_image = (\n", + " keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " )\n", " return point_cloud_img, rendered_embeddings_image\n", "\n", + "\n", "vmf_score = lambda q, q_mean, conc: tfp.distributions.VonMisesFisher(\n", " q_mean, conc\n", ").log_prob(q)\n", @@ -1107,28 +1175,30 @@ "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m),(m)->()',\n", + " signature=\"(m),(m)->()\",\n", " excluded=(2,),\n", ")\n", - "def vmf_vectorize(\n", - " embeddings,\n", - " embeddings_mean,\n", - " conc\n", - "):\n", + "def vmf_vectorize(embeddings, embeddings_mean, conc):\n", " return vmf_score(embeddings, embeddings_mean, conc)\n", "\n", "\n", "def score_pose(pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings):\n", - " _,rendered_embedding_image = render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings)\n", + " _, rendered_embedding_image = render_embedding_image(\n", + " pose, keypoint_coordinates, keypoint_embeddings\n", + " )\n", " scores = vmf_vectorize(observed_embeddings, rendered_embedding_image, 1000.0)\n", " return scores\n", "\n", + "\n", "def get_pca(embeddings):\n", - " features_flat = torch.from_numpy(np.array(embeddings).reshape(-1, embeddings.shape[-1]))\n", + " features_flat = torch.from_numpy(\n", + " np.array(embeddings).reshape(-1, embeddings.shape[-1])\n", + " )\n", " U, S, V = torch.pca_lowrank(features_flat - features_flat.mean(0), niter=10)\n", " proj_PCA = jnp.array(V[:, :3])\n", " return proj_PCA\n", "\n", + "\n", "def get_colors(features, proj_V):\n", " features_flat = features.reshape(-1, features.shape[-1])\n", " feat_rgb = features_flat @ proj_V\n", @@ -1136,8 +1206,9 @@ " feat_rgb = feat_rgb.reshape(features.shape[:-1] + (3,))\n", " return feat_rgb\n", "\n", + "\n", "score_pose_jit = jax.jit(score_pose)\n", - "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None )))" + "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None)))" ] }, { @@ -1154,7 +1225,9 @@ "proj_V = get_pca(keypoint_embeddings)\n", "colors = get_colors(keypoint_embeddings, proj_V)\n", "b.clear()\n", - "obj = g.PointCloud(np.transpose(keypoint_coordinates)*10.0, np.transpose(colors), size=0.1)\n", + "obj = g.PointCloud(\n", + " np.transpose(keypoint_coordinates) * 10.0, np.transpose(colors), size=0.1\n", + ")\n", "b.meshcatviz.VISUALIZER[\"2\"].set_object(obj)" ] }, @@ -1180,8 +1253,10 @@ "metadata": {}, "outputs": [], "source": [ - 
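
`vmf_score` above uses the von Mises-Fisher distribution as a likelihood over unit embedding vectors: its log-density is `conc * <q, q_mean>` plus a normalizer, so it rewards cosine alignment, with `conc` acting as an inverse temperature. A small check, assuming the JAX substrate of TensorFlow Probability:

    import jax.numpy as jnp
    import tensorflow_probability.substrates.jax as tfp

    vmf = tfp.distributions.VonMisesFisher(
        mean_direction=jnp.array([1.0, 0.0, 0.0]), concentration=1000.0
    )
    aligned = vmf.log_prob(jnp.array([1.0, 0.0, 0.0]))
    orthogonal = vmf.log_prob(jnp.array([0.0, 1.0, 0.0]))
    # aligned - orthogonal == conc * (1 - 0) = 1000 nats
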
"b.RENDERER.render(jnp.eye(4)[None,...], jnp.array([0]));\n", - "pc_img, rendered_embedding_image = render_embedding_image(object_poses[0], keypoint_coordinates, keypoint_embeddings);" + "b.RENDERER.render(jnp.eye(4)[None, ...], jnp.array([0]))\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " object_poses[0], keypoint_coordinates, keypoint_embeddings\n", + ")" ] }, { @@ -1195,7 +1270,7 @@ "source": [ "IDX = 15\n", "test_rgbd = images[IDX]\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "# b.get_rgb_image(test_rgbd.rgb)" ] @@ -1209,7 +1284,14 @@ }, "outputs": [], "source": [ - "posterior = jnp.concatenate([score_pose_parallel_jit(i, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1) for i in jnp.array_split(object_poses, 10)])\n", + "posterior = jnp.concatenate(\n", + " [\n", + " score_pose_parallel_jit(\n", + " i, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " for i in jnp.array_split(object_poses, 10)\n", + " ]\n", + ")\n", "print(posterior.argmax())\n", "best_pose = object_poses[posterior.argmax()]\n", "print(best_pose)" @@ -1245,21 +1327,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(best_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " best_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -1270,9 +1353,11 @@ "metadata": {}, "outputs": [], "source": [ - "random_pose = b.transform_from_pos(jnp.array([0.0, 0.0, 0.6])) @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", + "random_pose = b.transform_from_pos(\n", + " jnp.array([0.0, 0.0, 0.6])\n", + ") @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", "test_rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(random_pose))\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "b.get_rgb_image(test_rgbd.scale_rgbd(0.2).rgb)" ] @@ -1304,12 +1389,14 @@ "metadata": {}, "outputs": [], "source": [ - "match_scores = jnp.einsum(\"abk,ck\",observed_embeddings, keypoint_embeddings)\n", + "match_scores = jnp.einsum(\"abk,ck\", observed_embeddings, keypoint_embeddings)\n", "top_match = 
match_scores.max(-1)\n", "top_match_idx = match_scores.argmax(-1)\n", "\n", "THRESHOLD = 0.8\n", - "match_mask = (top_match > THRESHOLD) * (test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far)\n", + "match_mask = (top_match > THRESHOLD) * (\n", + " test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far\n", + ")\n", "print(match_mask.sum())\n", "b.get_depth_image(1.0 * match_mask)" ] @@ -1321,14 +1408,16 @@ "metadata": {}, "outputs": [], "source": [ - "observed_point_cloud_image = b.unproject_depth_jit(test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics)\n", + "observed_point_cloud_image = b.unproject_depth_jit(\n", + " test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics\n", + ")\n", "\n", - "observed_match_coordinates = observed_point_cloud_image[match_mask,:]\n", - "model_coordinates = keypoint_coordinates[top_match_idx[match_mask],:]\n", + "observed_match_coordinates = observed_point_cloud_image[match_mask, :]\n", + "model_coordinates = keypoint_coordinates[top_match_idx[match_mask], :]\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1,3))\n", - "b.show_cloud(\"2\", model_coordinates.reshape(-1,3), color=b.RED)" + "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1, 3))\n", + "b.show_cloud(\"2\", model_coordinates.reshape(-1, 3), color=b.RED)" ] }, { @@ -1339,11 +1428,13 @@ "outputs": [], "source": [ "b.clear()\n", - "estimated_pose = b.estimate_transform_between_clouds(model_coordinates, observed_match_coordinates)\n", - "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10],random_pose, 0.1, 10.0)\n", + "estimated_pose = b.estimate_transform_between_clouds(\n", + " model_coordinates, observed_match_coordinates\n", + ")\n", + "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10], random_pose, 0.1, 10.0)\n", "b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", "b.set_pose(\"mesh\", estimated_pose)\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -1353,8 +1444,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -1375,16 +1474,20 @@ "outputs": [], "source": [ "for _ in range(20):\n", - " potential_poses = gaussian_vmf_parallel(keys,estimated_pose, 0.01, 20000.0)\n", - " current_score = score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", - " scores = score_pose_parallel_jit(potential_poses, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " potential_poses = gaussian_vmf_parallel(keys, estimated_pose, 0.01, 20000.0)\n", + " current_score = score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " scores = 
score_pose_parallel_jit(\n", + " potential_poses, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", " if scores.max() > current_score:\n", " estimated_pose = potential_poses[scores.argmax()]\n", " keys = split_jit(keys[0], 100)\n", " print(scores.max(), current_score)\n", " b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", " b.set_pose(\"mesh\", estimated_pose)\n", - " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -1394,8 +1497,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -1416,21 +1527,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(estimated_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -1443,7 +1555,7 @@ "source": [ "# b.clear()\n", "# b.show_trimesh(\"mesh\", b.RENDERER.meshes[obj_idx])\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { diff --git a/scripts/experiments/deeplearning/dino/dino_interactive.ipynb b/scripts/experiments/deeplearning/dino/dino_interactive.ipynb index 30fd6cb0..f55a2773 100644 --- a/scripts/experiments/deeplearning/dino/dino_interactive.ipynb +++ b/scripts/experiments/deeplearning/dino/dino_interactive.ipynb @@ -44,7 +44,7 @@ "import bayes3d.utils.ycb_loader\n", "from bayes3d.viz.open3dviz import Open3DVisualizer\n", "from tqdm import tqdm\n", - "import open3d as o3d\n" + "import open3d as o3d" ] }, { @@ -69,7 +69,7 @@ "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "dinov2_vitg14 = 
torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')\n", + "dinov2_vitg14 = torch.hub.load(\"facebookresearch/dinov2\", \"dinov2_vits14\")\n", "dino = dinov2_vitg14.to(device) # Same issue with larger model" ] }, @@ -83,19 +83,26 @@ "outputs": [], "source": [ "def get_embeddings(dinov2_vitg14, rgb):\n", - " img = b.get_rgb_image(rgb).convert('RGB')\n", + " img = b.get_rgb_image(rgb).convert(\"RGB\")\n", " patch_w, patch_h = np.array(img.size) // 14\n", - " transform = T.Compose([\n", - " T.GaussianBlur(9, sigma=(0.1, 2.0)),\n", - " T.Resize((patch_h * 14, patch_w * 14)),\n", - " T.CenterCrop((patch_h * 14, patch_w * 14)),\n", - " T.ToTensor(),\n", - " T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n", - " ])\n", + " transform = T.Compose(\n", + " [\n", + " T.GaussianBlur(9, sigma=(0.1, 2.0)),\n", + " T.Resize((patch_h * 14, patch_w * 14)),\n", + " T.CenterCrop((patch_h * 14, patch_w * 14)),\n", + " T.ToTensor(),\n", + " T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n", + " ]\n", + " )\n", " tensor = transform(img)[:3].unsqueeze(0).to(device)\n", " with torch.no_grad():\n", " features_dict = dinov2_vitg14.forward_features(tensor)\n", - " features = features_dict['x_norm_patchtokens'][0].reshape((patch_h, patch_w, 384)).permute(2, 0, 1).unsqueeze(0)\n", + " features = (\n", + " features_dict[\"x_norm_patchtokens\"][0]\n", + " .reshape((patch_h, patch_w, 384))\n", + " .permute(2, 0, 1)\n", + " .unsqueeze(0)\n", + " )\n", " img_feat_norm = torch.nn.functional.normalize(features, dim=1)\n", " output = jnp.array(img_feat_norm.cpu().detach().numpy())[0]\n", " del img_feat_norm\n", @@ -103,7 +110,7 @@ " del tensor\n", " del features_dict\n", " torch.cuda.empty_cache()\n", - " return jnp.transpose(output, (1,2,0))" + " return jnp.transpose(output, (1, 2, 0))" ] }, { @@ -115,15 +122,11 @@ }, "outputs": [], "source": [ - "w,h = 1400,1400\n", + "w, h = 1400, 1400\n", "intrinsics = b.Intrinsics(\n", - " height=h,\n", - " width=w,\n", - " fx=2000.0, fy=2000.0,\n", - " cx=w/2.0, cy=h/2.0,\n", - " near=0.001, far=6.0\n", + " height=h, width=w, fx=2000.0, fy=2000.0, cx=w / 2.0, cy=h / 2.0, near=0.001, far=6.0\n", ")\n", - "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0/14.0)\n", + "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0 / 14.0)\n", "scaled_down_intrinsics" ] }, @@ -148,11 +151,13 @@ }, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "mesh_paths = []\n", - "for idx in range(1,22):\n", - " mesh_paths.append(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"))\n", - "SCALING_FACTOR = 1.0/1000.0\n" + "for idx in range(1, 22):\n", + " mesh_paths.append(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", + " )\n", + "SCALING_FACTOR = 1.0 / 1000.0" ] }, { @@ -173,7 +178,7 @@ "metadata": {}, "outputs": [], "source": [ - "viz = Open3DVisualizer(intrinsics)\n" + "viz = Open3DVisualizer(intrinsics)" ] }, { @@ -198,11 +203,19 @@ "metadata": {}, "outputs": [], "source": [ - "object_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, 0.0]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]])\n", + 
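Note (not part of the diff): the `1.0 / 14.0` factor passed to `scale_camera_parameters` above matches DINOv2's 14-pixel patch size, so the grid of patch embeddings stays pixel-aligned with the rendered depth and point-cloud images. A small sketch of the arithmetic, using the 1400x1400, fx=fy=2000 intrinsics from this notebook (pure Python, no Bayes3D dependency):

    # Patch-size bookkeeping for a ViT with 14x14-pixel patches (DINOv2).
    # Values mirror the notebook's intrinsics; no library calls needed.
    patch = 14
    w = h = 1400
    fx = fy = 2000.0
    grid_w, grid_h = w // patch, h // patch          # patch-token grid size
    fx_s, fy_s = fx / patch, fy / patch              # focal length in patch units
    cx_s, cy_s = (w / 2.0) / patch, (h / 2.0) / patch
    print(grid_w, grid_h)   # 100 100: one embedding per patch
    print(fx_s, cx_s)       # ~142.86, 50.0: scaled intrinsics for the patch grid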
"object_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.6, 0.0]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]\n", + " ]\n", + ")\n", "# for (i, pose) in enumerate(object_poses):\n", "# b.show_pose(f\"{i}\", pose)" ] @@ -218,7 +231,7 @@ "for i, pose in tqdm(enumerate(object_poses)):\n", " # if i > 0:\n", " # mesh.meshes[0].mesh.transform(b.inv\n", - " \n", + "\n", " # viz.render.scene.add_model(f\"1\", mesh)\n", " rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(pose))\n", " images.append(rgbd)\n", @@ -249,6 +262,8 @@ " dot_products = jnp.einsum(\"i, abi->ab\", target_embedding, all_embeddings)\n", " heatmap = dot_products / 2.0 + 0.5\n", " return heatmap\n", + "\n", + "\n", "# heatmap = get_heatmap(img1_embedding[45,45], img2_embedding)\n", "# scaled_up_heatmap = b.utils.resize(heatmap, img2.shape[0], img2.shape[1])\n", "# plt.matshow(heatmap)" @@ -262,14 +277,13 @@ "outputs": [], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('51', '1', bop_ycb_dir)\n", - "\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"51\", \"1\", bop_ycb_dir)\n", "\n", "\n", "img2 = rgbd.scale_rgbd(3.0).rgb\n", "img2_embedding = get_embeddings(dinov2_vitg14, img2)\n", "\n", - "b.get_rgb_image(img2)\n" + "b.get_rgb_image(img2)" ] }, { @@ -287,19 +301,21 @@ "IDX1 = 70\n", "img1 = images[IDX1].rgb\n", "img1_embedding = get_embeddings(dinov2_vitg14, img1)\n", - "img1_resized = images[IDX1].scale_rgbd(1.0/14.0).rgb\n", + "img1_resized = images[IDX1].scale_rgbd(1.0 / 14.0).rgb\n", "\n", - "axes[0].imshow(img1 /255.0)\n", + "axes[0].imshow(img1 / 255.0)\n", "axes[1].imshow(img2 / 255.0)\n", + "\n", + "\n", "def onclick(event):\n", " print(\"hello\")\n", " # cos = nn.CosineSimilarity(dim=1)\n", "\n", " x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n", " x2, y2 = int(np.round(event.xdata / 14.0)), int(np.round(event.ydata / 14.0))\n", - " data.append((x,y))\n", + " data.append((x, y))\n", "\n", - " target_embedding = img1_embedding[y2, x2,:]\n", + " target_embedding = img1_embedding[y2, x2, :]\n", "\n", " # dot_products = jnp.einsum(\"i, abi->ab\", target_embedding, img2_embedding)\n", "\n", @@ -308,19 +324,20 @@ "\n", " axes[0].clear()\n", " axes[1].clear()\n", - " axes[0].imshow(img1 /255.0)\n", + " axes[0].imshow(img1 / 255.0)\n", " axes[1].imshow(img2 / 255.0)\n", - " axes[0].scatter(x, y, c='r', s=10.0)\n", + " axes[0].scatter(x, y, c=\"r\", s=10.0)\n", "\n", " max_yx = np.unravel_index(scaled_up_heatmap.argmax(), scaled_up_heatmap.shape)\n", " # axes[1].axis('off')\n", - " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap='viridis')\n", + " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap=\"viridis\")\n", " # axes[1].axis('off')\n", - " axes[1].scatter(max_yx[1], max_yx[0], c='r', s=10)\n", + " axes[1].scatter(max_yx[1], max_yx[0], c=\"r\", s=10)\n", " # axes[1].set_title('target image')\n", " # gc.collect()\n", "\n", - "fig.canvas.mpl_connect('button_press_event', onclick)\n", + "\n", + "fig.canvas.mpl_connect(\"button_press_event\", onclick)\n", "plt.show()" ] }, @@ -349,12 +366,14 @@ "x = np.linspace(0, 2 * np.pi)\n", "fig = plt.figure()\n", "ax = fig.add_subplot(1, 1, 1)\n", - "line, = 
ax.plot(x, np.sin(x))\n", + "(line,) = ax.plot(x, np.sin(x))\n", + "\n", "\n", - "def update(w = 1.0):\n", + "def update(w=1.0):\n", " line.set_ydata(np.sin(w * x))\n", " fig.canvas.draw_idle()\n", "\n", + "\n", "interact(update);" ] }, @@ -388,7 +407,7 @@ "source": [ "num_images = len(images)\n", "num_training_images = 10\n", - "training_indices = jnp.arange(0,num_images-1, num_images // num_training_images)\n", + "training_indices = jnp.arange(0, num_images - 1, num_images // num_training_images)\n", "# b.hstack_images([\n", "# b.get_rgb_image(images[idx].rgb) for idx in training_indices\n", "# ])" @@ -408,44 +427,69 @@ "sparse_descriptors = []\n", "for iteration in range(len(training_indices)):\n", " index = training_indices[iteration]\n", - " index_next = training_indices[(iteration+1) % len(training_indices)]\n", + " index_next = training_indices[(iteration + 1) % len(training_indices)]\n", " print(index, index_next)\n", " keys = jax.random.split(key)[1]\n", - " \n", + "\n", " training_image = images[index]\n", " object_pose = object_poses[index]\n", - " \n", - " scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", + "\n", + " scaled_down_training_image = training_image.scale_rgbd(1.0 / 14.0)\n", " embedding_image = get_embeddings(training_image)\n", " embedding_image_next = get_embeddings(images[index_next])\n", - " \n", - " foreground_mask = (jnp.inf != scaled_down_training_image.depth)\n", + "\n", + " foreground_mask = jnp.inf != scaled_down_training_image.depth\n", " foreground_pixel_coordinates = jnp.transpose(jnp.vstack(jnp.where(foreground_mask)))\n", - " \n", + "\n", " depth = jnp.array(scaled_down_training_image.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image = b.t3d.unproject_depth(depth, scaled_down_training_image.intrinsics)\n", - " point_cloud_image_object_frame = b.t3d.apply_transform(point_cloud_image, b.t3d.inverse_pose(object_pose))\n", - " \n", - " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0/14.0)\n", + " point_cloud_image = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image.intrinsics\n", + " )\n", + " point_cloud_image_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image, b.t3d.inverse_pose(object_pose)\n", + " )\n", + "\n", + " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0 / 14.0)\n", " depth = jnp.array(scaled_down_training_image_next.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image_next = b.t3d.unproject_depth(depth, scaled_down_training_image_next.intrinsics)\n", - " point_cloud_image_next_object_frame = b.t3d.apply_transform(point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next]))\n", - " \n", - " embeddings_subset = embedding_image[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " coordinates_subset = point_cloud_image_object_frame[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " similarity_embedding = jnp.einsum(\"abi, ki->abk\", embedding_image_next, embeddings_subset)\n", + " point_cloud_image_next = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image_next.intrinsics\n", + " )\n", + " point_cloud_image_next_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next])\n", + " )\n", + "\n", + " embeddings_subset = embedding_image[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " coordinates_subset = 
point_cloud_image_object_frame[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " similarity_embedding = jnp.einsum(\n", + " \"abi, ki->abk\", embedding_image_next, embeddings_subset\n", + " )\n", " best_match = similarity_embedding.argmax(-1)\n", - " distance_to_best_match = jnp.linalg.norm(point_cloud_image_next_object_frame - coordinates_subset[best_match,:], axis=-1)\n", - " \n", + " distance_to_best_match = jnp.linalg.norm(\n", + " point_cloud_image_next_object_frame - coordinates_subset[best_match, :], axis=-1\n", + " )\n", + "\n", " selected = (distance_to_best_match < 0.01) * (similarity_embedding.max(-1) > 0.9)\n", " subset = jnp.unique(best_match[selected])\n", "\n", - "\n", - " _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " _keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", + " _keypoint_embeddings = embedding_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " keypoint_world_coordinates = point_cloud_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " _keypoint_coordinates = b.t3d.apply_transform(\n", + " keypoint_world_coordinates, b.t3d.inverse_pose(object_pose)\n", + " )\n", "\n", " keypoint_coordinates.append(_keypoint_coordinates)\n", " keypoint_embeddings.append(_keypoint_embeddings)\n", @@ -478,10 +522,10 @@ "# index_next = training_indices[(iteration+1) % len(training_indices)]\n", "# print(index, index_next)\n", "# keys = jax.random.split(key)[1]\n", - " \n", + "\n", "# training_image = images[index]\n", "# object_pose = object_poses[index]\n", - " \n", + "\n", "# scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", "# embedding_image = get_embeddings(training_image)\n", "\n", @@ -499,7 +543,7 @@ "# keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", "# _keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", "# _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " \n", + "\n", "# keypoint_coordinates.append(_keypoint_coordinates)\n", "# keypoint_embeddings.append(_keypoint_embeddings)\n", "# del embedding_image\n", @@ -516,24 +560,29 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings):\n", - " point_cloud_img = b.RENDERER.render(pose[None,...], jnp.array([0]))[:,:,:3]\n", - " point_cloud_img_in_object_frame = b.t3d.apply_transform(point_cloud_img, b.t3d.inverse_pose(pose))\n", - "\n", - " distances_to_keypoints = (\n", - " jnp.linalg.norm(point_cloud_img_in_object_frame[:, :,None,...] 
- keypoint_coordinates[None, None,:,...],\n", - " axis=-1\n", - " ))\n", + " point_cloud_img = b.RENDERER.render(pose[None, ...], jnp.array([0]))[:, :, :3]\n", + " point_cloud_img_in_object_frame = b.t3d.apply_transform(\n", + " point_cloud_img, b.t3d.inverse_pose(pose)\n", + " )\n", + "\n", + " distances_to_keypoints = jnp.linalg.norm(\n", + " point_cloud_img_in_object_frame[:, :, None, ...]\n", + " - keypoint_coordinates[None, None, :, ...],\n", + " axis=-1,\n", + " )\n", " index_of_nearest_keypoint = distances_to_keypoints.argmin(2)\n", " distance_to_nearest_keypoints = distances_to_keypoints.min(2)\n", "\n", " DISTANCE_THRESHOLD = 0.04\n", - " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[...,None]\n", + " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[..., None]\n", " selected_keypoints = keypoint_coordinates[index_of_nearest_keypoint]\n", - " rendered_embeddings_image = keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " rendered_embeddings_image = (\n", + " keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " )\n", " return point_cloud_img, rendered_embeddings_image\n", "\n", + "\n", "vmf_score = lambda q, q_mean, conc: tfp.distributions.VonMisesFisher(\n", " q_mean, conc\n", ").log_prob(q)\n", @@ -545,28 +594,30 @@ "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m),(m)->()',\n", + " signature=\"(m),(m)->()\",\n", " excluded=(2,),\n", ")\n", - "def vmf_vectorize(\n", - " embeddings,\n", - " embeddings_mean,\n", - " conc\n", - "):\n", + "def vmf_vectorize(embeddings, embeddings_mean, conc):\n", " return vmf_score(embeddings, embeddings_mean, conc)\n", "\n", "\n", "def score_pose(pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings):\n", - " _,rendered_embedding_image = render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings)\n", + " _, rendered_embedding_image = render_embedding_image(\n", + " pose, keypoint_coordinates, keypoint_embeddings\n", + " )\n", " scores = vmf_vectorize(observed_embeddings, rendered_embedding_image, 1000.0)\n", " return scores\n", "\n", + "\n", "def get_pca(embeddings):\n", - " features_flat = torch.from_numpy(np.array(embeddings).reshape(-1, embeddings.shape[-1]))\n", + " features_flat = torch.from_numpy(\n", + " np.array(embeddings).reshape(-1, embeddings.shape[-1])\n", + " )\n", " U, S, V = torch.pca_lowrank(features_flat - features_flat.mean(0), niter=10)\n", " proj_PCA = jnp.array(V[:, :3])\n", " return proj_PCA\n", "\n", + "\n", "def get_colors(features, proj_V):\n", " features_flat = features.reshape(-1, features.shape[-1])\n", " feat_rgb = features_flat @ proj_V\n", @@ -574,8 +625,9 @@ " feat_rgb = feat_rgb.reshape(features.shape[:-1] + (3,))\n", " return feat_rgb\n", "\n", + "\n", "score_pose_jit = jax.jit(score_pose)\n", - "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None )))" + "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None)))" ] }, { @@ -592,7 +644,9 @@ "proj_V = get_pca(keypoint_embeddings)\n", "colors = get_colors(keypoint_embeddings, proj_V)\n", "b.clear()\n", - "obj = g.PointCloud(np.transpose(keypoint_coordinates)*10.0, np.transpose(colors), size=0.1)\n", + "obj = g.PointCloud(\n", + " np.transpose(keypoint_coordinates) * 10.0, np.transpose(colors), size=0.1\n", + ")\n", "b.meshcatviz.VISUALIZER[\"2\"].set_object(obj)" ] }, @@ -618,8 +672,10 @@ "metadata": {}, "outputs": [], "source": [ - 
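Note (not part of the diff): `render_embedding_image` above associates every rendered 3D point with its nearest model keypoint and masks out matches beyond a distance threshold. A minimal runnable sketch of that association step with toy shapes (H, W, K, D are illustrative; 0.04 is the notebook's DISTANCE_THRESHOLD):

    import jax.numpy as jnp

    H, W, K, D = 4, 5, 7, 16
    points = jnp.zeros((H, W, 3))  # stand-in for a rendered point-cloud image
    # One keypoint at the origin, the rest far away.
    keypoints = jnp.vstack([jnp.zeros((1, 3)), jnp.ones((K - 1, 3))])
    embeddings = jnp.arange(K * D, dtype=jnp.float32).reshape(K, D)

    # (H, W, K) pairwise distances via broadcasting, as in the function above.
    dists = jnp.linalg.norm(
        points[:, :, None, :] - keypoints[None, None, :, :], axis=-1
    )
    nearest = dists.argmin(2)                  # (H, W) index of closest keypoint
    valid = (dists.min(2) < 0.04)[..., None]   # (H, W, 1) threshold mask
    feature_image = embeddings[nearest] * valid
    print(feature_image.shape)                 # (4, 5, 16)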
"b.RENDERER.render(jnp.eye(4)[None,...], jnp.array([0]));\n", - "pc_img, rendered_embedding_image = render_embedding_image(object_poses[0], keypoint_coordinates, keypoint_embeddings);" + "b.RENDERER.render(jnp.eye(4)[None, ...], jnp.array([0]))\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " object_poses[0], keypoint_coordinates, keypoint_embeddings\n", + ")" ] }, { @@ -633,7 +689,7 @@ "source": [ "IDX = 15\n", "test_rgbd = images[IDX]\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "# b.get_rgb_image(test_rgbd.rgb)" ] @@ -647,7 +703,14 @@ }, "outputs": [], "source": [ - "posterior = jnp.concatenate([score_pose_parallel_jit(i, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1) for i in jnp.array_split(object_poses, 10)])\n", + "posterior = jnp.concatenate(\n", + " [\n", + " score_pose_parallel_jit(\n", + " i, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " for i in jnp.array_split(object_poses, 10)\n", + " ]\n", + ")\n", "print(posterior.argmax())\n", "best_pose = object_poses[posterior.argmax()]\n", "print(best_pose)" @@ -683,21 +746,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(best_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " best_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -708,9 +772,11 @@ "metadata": {}, "outputs": [], "source": [ - "random_pose = b.transform_from_pos(jnp.array([0.0, 0.0, 0.6])) @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", + "random_pose = b.transform_from_pos(\n", + " jnp.array([0.0, 0.0, 0.6])\n", + ") @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", "test_rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(random_pose))\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "b.get_rgb_image(test_rgbd.scale_rgbd(0.2).rgb)" ] @@ -742,12 +808,14 @@ "metadata": {}, "outputs": [], "source": [ - "match_scores = jnp.einsum(\"abk,ck\",observed_embeddings, keypoint_embeddings)\n", + "match_scores = jnp.einsum(\"abk,ck\", observed_embeddings, keypoint_embeddings)\n", "top_match = match_scores.max(-1)\n", 
"top_match_idx = match_scores.argmax(-1)\n", "\n", "THRESHOLD = 0.8\n", - "match_mask = (top_match > THRESHOLD) * (test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far)\n", + "match_mask = (top_match > THRESHOLD) * (\n", + " test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far\n", + ")\n", "print(match_mask.sum())\n", "b.get_depth_image(1.0 * match_mask)" ] @@ -759,14 +827,16 @@ "metadata": {}, "outputs": [], "source": [ - "observed_point_cloud_image = b.unproject_depth_jit(test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics)\n", + "observed_point_cloud_image = b.unproject_depth_jit(\n", + " test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics\n", + ")\n", "\n", - "observed_match_coordinates = observed_point_cloud_image[match_mask,:]\n", - "model_coordinates = keypoint_coordinates[top_match_idx[match_mask],:]\n", + "observed_match_coordinates = observed_point_cloud_image[match_mask, :]\n", + "model_coordinates = keypoint_coordinates[top_match_idx[match_mask], :]\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1,3))\n", - "b.show_cloud(\"2\", model_coordinates.reshape(-1,3), color=b.RED)" + "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1, 3))\n", + "b.show_cloud(\"2\", model_coordinates.reshape(-1, 3), color=b.RED)" ] }, { @@ -777,11 +847,13 @@ "outputs": [], "source": [ "b.clear()\n", - "estimated_pose = b.estimate_transform_between_clouds(model_coordinates, observed_match_coordinates)\n", - "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10],random_pose, 0.1, 10.0)\n", + "estimated_pose = b.estimate_transform_between_clouds(\n", + " model_coordinates, observed_match_coordinates\n", + ")\n", + "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10], random_pose, 0.1, 10.0)\n", "b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", "b.set_pose(\"mesh\", estimated_pose)\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -791,8 +863,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -813,16 +893,20 @@ "outputs": [], "source": [ "for _ in range(20):\n", - " potential_poses = gaussian_vmf_parallel(keys,estimated_pose, 0.01, 20000.0)\n", - " current_score = score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", - " scores = score_pose_parallel_jit(potential_poses, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " potential_poses = gaussian_vmf_parallel(keys, estimated_pose, 0.01, 20000.0)\n", + " current_score = score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " scores = score_pose_parallel_jit(\n", + " potential_poses, 
keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", " if scores.max() > current_score:\n", " estimated_pose = potential_poses[scores.argmax()]\n", " keys = split_jit(keys[0], 100)\n", " print(scores.max(), current_score)\n", " b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", " b.set_pose(\"mesh\", estimated_pose)\n", - " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -832,8 +916,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -854,21 +946,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(estimated_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -881,7 +974,7 @@ "source": [ "# b.clear()\n", "# b.show_trimesh(\"mesh\", b.RENDERER.meshes[obj_idx])\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { diff --git a/scripts/experiments/deeplearning/dino/f3dm_dino_extraction.ipynb b/scripts/experiments/deeplearning/dino/f3dm_dino_extraction.ipynb index ff8960a6..15ddb227 100644 --- a/scripts/experiments/deeplearning/dino/f3dm_dino_extraction.ipynb +++ b/scripts/experiments/deeplearning/dino/f3dm_dino_extraction.ipynb @@ -21,7 +21,7 @@ "import trimesh\n", "import jax\n", "import os\n", - "from tqdm import tqdm\n" + "from tqdm import tqdm" ] }, { @@ -50,9 +50,11 @@ "rgbds = []\n", "images = []\n", "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "for scene_id in [1,21,34]:\n", - " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('52', f'{scene_id}', 
bop_ycb_dir)\n", - " full_mask = (jnp.stack(masks).sum(0) > 0)\n", + "for scene_id in [1, 21, 34]:\n", + " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\n", + " \"52\", f\"{scene_id}\", bop_ycb_dir\n", + " )\n", + " full_mask = jnp.stack(masks).sum(0) > 0\n", " images.append(b.get_rgb_image(rgbd.rgb * full_mask[..., None]))" ] }, @@ -72,7 +74,7 @@ "source": [ "image_paths = []\n", "for i in range(len(images)):\n", - " name = f'img_{i}.png'\n", + " name = f\"img_{i}.png\"\n", " image_paths.append(name)\n", " images[i].save(name)\n", "print(image_paths)" @@ -175,7 +177,9 @@ "# Visualize the embeddings\n", "plt.figure()\n", "plt.suptitle(\"CLIP (2nd row) and DINO (3rd row) Features PCA\")\n", - "for i, (image_path, clip_pca_, dino_pca_) in enumerate(zip(image_paths, clip_pca, dino_pca)):\n", + "for i, (image_path, clip_pca_, dino_pca_) in enumerate(\n", + " zip(image_paths, clip_pca, dino_pca)\n", + "):\n", " plt.subplot(3, len(image_paths), i + 1)\n", " plt.imshow(Image.open(image_path))\n", " plt.title(os.path.basename(image_path))\n", @@ -192,7 +196,7 @@ "plt.tight_layout()\n", "plt.savefig(\"demo_extract_features.png\")\n", "print(\"Saved plot to demo_extract_features.png\")\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -224,7 +228,8 @@ "%matplotlib widget\n", "IDX1 = 70\n", "img1 = images[IDX1].rgb\n", - "img1_resized = images[IDX1].scale_rgbd(1.0/14.0).rgb\n", + "img1_resized = images[IDX1].scale_rgbd(1.0 / 14.0).rgb\n", + "\n", "\n", "def onclick(event):\n", " print(\"hello\")\n", @@ -232,9 +237,9 @@ "\n", " x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n", " x2, y2 = int(np.round(event.xdata / 14.0)), int(np.round(event.ydata / 14.0))\n", - " data.append((x,y))\n", + " data.append((x, y))\n", "\n", - " target_embedding = img1_embedding[y2, x2,:]\n", + " target_embedding = img1_embedding[y2, x2, :]\n", "\n", " # dot_products = jnp.einsum(\"i, abi->ab\", target_embedding, img2_embedding)\n", "\n", @@ -243,19 +248,20 @@ "\n", " axes[0].clear()\n", " axes[1].clear()\n", - " axes[0].imshow(img1 /255.0)\n", + " axes[0].imshow(img1 / 255.0)\n", " axes[1].imshow(img2 / 255.0)\n", - " axes[0].scatter(x, y, c='r', s=10.0)\n", + " axes[0].scatter(x, y, c=\"r\", s=10.0)\n", "\n", " max_yx = np.unravel_index(scaled_up_heatmap.argmax(), scaled_up_heatmap.shape)\n", " # axes[1].axis('off')\n", - " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap='viridis')\n", + " axes[1].imshow(255 * scaled_up_heatmap, alpha=0.45, cmap=\"viridis\")\n", " # axes[1].axis('off')\n", - " axes[1].scatter(max_yx[1], max_yx[0], c='r', s=10)\n", + " axes[1].scatter(max_yx[1], max_yx[0], c=\"r\", s=10)\n", " # axes[1].set_title('target image')\n", " # gc.collect()\n", "\n", - "fig.canvas.mpl_connect('button_press_event', onclick)\n", + "\n", + "fig.canvas.mpl_connect(\"button_press_event\", onclick)\n", "plt.show()" ] }, diff --git a/scripts/experiments/deeplearning/dino/test_dift.ipynb b/scripts/experiments/deeplearning/dino/test_dift.ipynb index 763d3c7e..6e8fd865 100644 --- a/scripts/experiments/deeplearning/dino/test_dift.ipynb +++ b/scripts/experiments/deeplearning/dino/test_dift.ipynb @@ -23,7 +23,7 @@ "import bayes3d.ycb_loader\n", "import bayes3d.o3d_viz\n", "from tqdm import tqdm\n", - "import open3d as o3d\n" + "import open3d as o3d" ] }, { @@ -95,7 +95,8 @@ " class_labels: Optional[torch.Tensor] = None,\n", " timestep_cond: Optional[torch.Tensor] = None,\n", " attention_mask: Optional[torch.Tensor] = None,\n", - " cross_attention_kwargs: Optional[Dict[str, Any]] = 
None):\n", + " cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n", + " ):\n", " r\"\"\"\n", " Args:\n", " sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor\n", @@ -157,7 +158,9 @@ "\n", " if self.class_embedding is not None:\n", " if class_labels is None:\n", - " raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n", + " raise ValueError(\n", + " \"class_labels should be provided when num_class_embeds > 0\"\n", + " )\n", "\n", " if self.config.class_embed_type == \"timestep\":\n", " class_labels = self.time_proj(class_labels)\n", @@ -171,7 +174,10 @@ " # 3. down\n", " down_block_res_samples = (sample,)\n", " for downsample_block in self.down_blocks:\n", - " if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n", + " if (\n", + " hasattr(downsample_block, \"has_cross_attention\")\n", + " and downsample_block.has_cross_attention\n", + " ):\n", " sample, res_samples = downsample_block(\n", " hidden_states=sample,\n", " temb=emb,\n", @@ -197,21 +203,25 @@ " # 5. up\n", " up_ft = {}\n", " for i, upsample_block in enumerate(self.up_blocks):\n", - "\n", " if i > np.max(up_ft_indices):\n", " break\n", "\n", " is_final_block = i == len(self.up_blocks) - 1\n", "\n", " res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n", - " down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n", + " down_block_res_samples = down_block_res_samples[\n", + " : -len(upsample_block.resnets)\n", + " ]\n", "\n", " # if we have not reached the final block and need to forward the\n", " # upsample size, we do it here\n", " if not is_final_block and forward_upsample_size:\n", " upsample_size = down_block_res_samples[-1].shape[2:]\n", "\n", - " if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n", + " if (\n", + " hasattr(upsample_block, \"has_cross_attention\")\n", + " and upsample_block.has_cross_attention\n", + " ):\n", " sample = upsample_block(\n", " hidden_states=sample,\n", " temb=emb,\n", @@ -223,16 +233,20 @@ " )\n", " else:\n", " sample = upsample_block(\n", - " hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n", + " hidden_states=sample,\n", + " temb=emb,\n", + " res_hidden_states_tuple=res_samples,\n", + " upsample_size=upsample_size,\n", " )\n", "\n", " if i in up_ft_indices:\n", " up_ft[i] = sample.detach()\n", "\n", " output = {}\n", - " output['up_ft'] = up_ft\n", + " output[\"up_ft\"] = up_ft\n", " return output\n", "\n", + "\n", "class OneStepSDPipeline(StableDiffusionPipeline):\n", " @torch.no_grad()\n", " def __call__(\n", @@ -245,28 +259,36 @@ " prompt_embeds: Optional[torch.FloatTensor] = None,\n", " callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n", " callback_steps: int = 1,\n", - " cross_attention_kwargs: Optional[Dict[str, Any]] = None\n", + " cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n", " ):\n", - "\n", " device = self._execution_device\n", - " latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor\n", + " latents = (\n", + " self.vae.encode(img_tensor).latent_dist.sample()\n", + " * self.vae.config.scaling_factor\n", + " )\n", " t = torch.tensor(t, dtype=torch.long, device=device)\n", " noise = torch.randn_like(latents).to(device)\n", " latents_noisy = self.scheduler.add_noise(latents, noise, t)\n", - " unet_output = self.unet(latents_noisy,\n", - " t,\n", - " 
up_ft_indices,\n", - " encoder_hidden_states=prompt_embeds,\n", - " cross_attention_kwargs=cross_attention_kwargs)\n", + " unet_output = self.unet(\n", + " latents_noisy,\n", + " t,\n", + " up_ft_indices,\n", + " encoder_hidden_states=prompt_embeds,\n", + " cross_attention_kwargs=cross_attention_kwargs,\n", + " )\n", " return unet_output\n", "\n", "\n", "class SDFeaturizer:\n", - " def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):\n", + " def __init__(self, sd_id=\"stabilityai/stable-diffusion-2-1\"):\n", " unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder=\"unet\")\n", - " onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)\n", + " onestep_pipe = OneStepSDPipeline.from_pretrained(\n", + " sd_id, unet=unet, safety_checker=None\n", + " )\n", " onestep_pipe.vae.decoder = None\n", - " onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder=\"scheduler\")\n", + " onestep_pipe.scheduler = DDIMScheduler.from_pretrained(\n", + " sd_id, subfolder=\"scheduler\"\n", + " )\n", " gc.collect()\n", " onestep_pipe = onestep_pipe.to(\"cuda\")\n", " onestep_pipe.enable_attention_slicing()\n", @@ -274,26 +296,30 @@ " self.pipe = onestep_pipe\n", "\n", " @torch.no_grad()\n", - " def forward(self,\n", - " img_tensor, # single image, [1,c,h,w]\n", - " prompt,\n", - " t=261,\n", - " up_ft_index=1,\n", - " ensemble_size=8):\n", - " img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w\n", + " def forward(\n", + " self,\n", + " img_tensor, # single image, [1,c,h,w]\n", + " prompt,\n", + " t=261,\n", + " up_ft_index=1,\n", + " ensemble_size=8,\n", + " ):\n", + " img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w\n", " prompt_embeds = self.pipe._encode_prompt(\n", " prompt=prompt,\n", - " device='cuda',\n", + " device=\"cuda\",\n", " num_images_per_prompt=1,\n", - " do_classifier_free_guidance=False) # [1, 77, dim]\n", + " do_classifier_free_guidance=False,\n", + " ) # [1, 77, dim]\n", " prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)\n", " unet_ft_all = self.pipe(\n", " img_tensor=img_tensor,\n", " t=t,\n", " up_ft_indices=[up_ft_index],\n", - " prompt_embeds=prompt_embeds)\n", - " unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w\n", - " unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w\n", + " prompt_embeds=prompt_embeds,\n", + " )\n", + " unet_ft = unet_ft_all[\"up_ft\"][up_ft_index] # ensem, c, h, w\n", + " unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w\n", " return unet_ft" ] }, @@ -307,7 +333,7 @@ "outputs": [], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.ycb_loader.get_test_img('49', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.ycb_loader.get_test_img(\"49\", \"1\", bop_ycb_dir)\n", "img1 = b.get_rgb_image(rgbd.rgb)\n", "img1" ] @@ -320,7 +346,7 @@ "outputs": [], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.ycb_loader.get_test_img('51', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.ycb_loader.get_test_img(\"51\", \"1\", bop_ycb_dir)\n", "img2 = b.get_rgb_image(rgbd.rgb)\n", "img2" ] @@ -332,7 +358,7 @@ "metadata": {}, "outputs": [], "source": [ - "dift = SDFeaturizer(sd_id='stabilityai/stable-diffusion-2-1')" + "dift = SDFeaturizer(sd_id=\"stabilityai/stable-diffusion-2-1\")" ] }, { @@ -353,10 +379,12 @@ "outputs": [], "source": [ "def get_embeddings(img):\n", - " img = 
img.convert('RGB')\n", + " img = img.convert(\"RGB\")\n", " img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2\n", " output_tensor = dift.forward(img_tensor, prompt=\"object\", ensemble_size=2)\n", - " output = jnp.transpose(jnp.array(output_tensor.cpu().detach().numpy())[0], (1,2,0))\n", + " output = jnp.transpose(\n", + " jnp.array(output_tensor.cpu().detach().numpy())[0], (1, 2, 0)\n", + " )\n", " del img_tensor\n", " del output_tensor\n", " torch.cuda.empty_cache()\n", @@ -364,12 +392,16 @@ "\n", "\n", "def get_embeddings_from_rgbd(rgbd):\n", - " img = b.get_rgb_image(rgbd.rgb).convert('RGB')\n", + " img = b.get_rgb_image(rgbd.rgb).convert(\"RGB\")\n", " return get_embeddings(img)\n", "\n", "\n", - "embeddings1 = jax.image.resize(get_embeddings(img1), (img1.height, img1.width, embeddings.shape[-1]), \"bilinear\")\n", - "embeddings2 = jax.image.resize(get_embeddings(img2), (img1.height, img1.width, embeddings.shape[-1]), \"bilinear\")\n" + "embeddings1 = jax.image.resize(\n", + " get_embeddings(img1), (img1.height, img1.width, embeddings.shape[-1]), \"bilinear\"\n", + ")\n", + "embeddings2 = jax.image.resize(\n", + " get_embeddings(img2), (img1.height, img1.width, embeddings.shape[-1]), \"bilinear\"\n", + ")" ] }, { @@ -381,24 +413,25 @@ "source": [ "class Demo:\n", " def __init__(self, imgs, ft, img_size):\n", - " self.ft = ft # NCHW\n", + " self.ft = ft # NCHW\n", " self.imgs = imgs\n", " self.num_imgs = len(imgs)\n", " self.img_size = img_size\n", "\n", " def plot_img_pairs(self, fig_size=3, alpha=0.45, scatter_size=70):\n", - "\n", - " fig, axes = plt.subplots(1, self.num_imgs, figsize=(fig_size*self.num_imgs, fig_size))\n", + " fig, axes = plt.subplots(\n", + " 1, self.num_imgs, figsize=(fig_size * self.num_imgs, fig_size)\n", + " )\n", "\n", " plt.tight_layout()\n", "\n", " for i in range(self.num_imgs):\n", " axes[i].imshow(self.imgs[i])\n", - " axes[i].axis('off')\n", + " axes[i].axis(\"off\")\n", " if i == 0:\n", - " axes[i].set_title('source image')\n", + " axes[i].set_title(\"source image\")\n", " else:\n", - " axes[i].set_title('target image')\n", + " axes[i].set_title(\"target image\")\n", "\n", " num_channel = self.ft.size(1)\n", " cos = nn.CosineSimilarity(dim=1)\n", @@ -406,46 +439,57 @@ " def onclick(event):\n", " if event.inaxes == axes[0]:\n", " with torch.no_grad():\n", - "\n", " x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n", "\n", " src_ft = self.ft[0].unsqueeze(0)\n", - " src_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(src_ft)\n", - " src_vec = src_ft[0, :, y, x].view(1, num_channel, 1, 1) # 1, C, 1, 1\n", + " src_ft = nn.Upsample(\n", + " size=(self.img_size, self.img_size), mode=\"bilinear\"\n", + " )(src_ft)\n", + " src_vec = src_ft[0, :, y, x].view(\n", + " 1, num_channel, 1, 1\n", + " ) # 1, C, 1, 1\n", " del src_ft\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", "\n", - " trg_ft = nn.Upsample(size=(self.img_size, self.img_size), mode='bilinear')(self.ft[1:])\n", + " trg_ft = nn.Upsample(\n", + " size=(self.img_size, self.img_size), mode=\"bilinear\"\n", + " )(self.ft[1:])\n", " cos_map = cos(src_vec, trg_ft).cpu().numpy() # N, H, W\n", - " \n", + "\n", " del trg_ft\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", "\n", " axes[0].clear()\n", " axes[0].imshow(self.imgs[0])\n", - " axes[0].axis('off')\n", - " axes[0].scatter(x, y, c='r', s=scatter_size)\n", - " axes[0].set_title('source image')\n", + " axes[0].axis(\"off\")\n", + " axes[0].scatter(x, y, c=\"r\", s=scatter_size)\n", + " 
axes[0].set_title(\"source image\")\n", "\n", " for i in range(1, self.num_imgs):\n", - " max_yx = np.unravel_index(cos_map[i-1].argmax(), cos_map[i-1].shape)\n", + " max_yx = np.unravel_index(\n", + " cos_map[i - 1].argmax(), cos_map[i - 1].shape\n", + " )\n", " axes[i].clear()\n", "\n", - " heatmap = cos_map[i-1]\n", - " heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1]\n", + " heatmap = cos_map[i - 1]\n", + " heatmap = (heatmap - np.min(heatmap)) / (\n", + " np.max(heatmap) - np.min(heatmap)\n", + " ) # Normalize to [0, 1]\n", " axes[i].imshow(self.imgs[i])\n", - " axes[i].imshow(255 * heatmap, alpha=alpha, cmap='viridis')\n", - " axes[i].axis('off')\n", - " axes[i].scatter(max_yx[1].item(), max_yx[0].item(), c='r', s=scatter_size)\n", - " axes[i].set_title('target image')\n", + " axes[i].imshow(255 * heatmap, alpha=alpha, cmap=\"viridis\")\n", + " axes[i].axis(\"off\")\n", + " axes[i].scatter(\n", + " max_yx[1].item(), max_yx[0].item(), c=\"r\", s=scatter_size\n", + " )\n", + " axes[i].set_title(\"target image\")\n", "\n", " del cos_map\n", " del heatmap\n", " gc.collect()\n", "\n", - " fig.canvas.mpl_connect('button_press_event', onclick)\n", + " fig.canvas.mpl_connect(\"button_press_event\", onclick)\n", " plt.show()" ] }, @@ -457,7 +501,7 @@ "outputs": [], "source": [ "%matplotlib widget\n", - "import matplotlib.pyplot as plt\n" + "import matplotlib.pyplot as plt" ] }, { @@ -467,7 +511,7 @@ "metadata": {}, "outputs": [], "source": [ - "scaling_factor= o.shape[0] / training_images[0].rgb.shape[0]\n", + "scaling_factor = o.shape[0] / training_images[0].rgb.shape[0]\n", "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, scaling_factor)\n", "scaled_down_intrinsics" ] @@ -482,7 +526,7 @@ "outputs": [], "source": [ "num_images = len(training_images)\n", - "training_indices = jnp.arange(0,num_images-1, num_images//4)\n", + "training_indices = jnp.arange(0, num_images - 1, num_images // 4)\n", "\n", "keypoint_coordinates = []\n", "keypoint_embeddings = []\n", @@ -490,34 +534,53 @@ "for idx in tqdm(training_indices):\n", " angle = training_angles[idx]\n", " training_image = training_images[idx]\n", - " pose = b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, 0.0]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " pose = b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.6, 0.0]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " ) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", "\n", " scaled_down_training_image = b.scale_rgbd(training_image, scaling_factor)\n", " embeddings = get_embeddings(training_image)\n", " # del embeddings\n", - " foreground_mask = (jnp.inf != scaled_down_training_image.depth)\n", + " foreground_mask = jnp.inf != scaled_down_training_image.depth\n", " foreground_pixel_coordinates = jnp.transpose(jnp.vstack(jnp.where(foreground_mask)))\n", "\n", - " NUM_KEYPOINTS_TO_SELECT = jnp.min(jnp.array([2000,foreground_pixel_coordinates.shape[0]]))\n", - " subset = jax.random.choice(jax.random.PRNGKey(10),foreground_pixel_coordinates.shape[0], shape=(NUM_KEYPOINTS_TO_SELECT,), replace=False)\n", + " NUM_KEYPOINTS_TO_SELECT = jnp.min(\n", + " jnp.array([2000, foreground_pixel_coordinates.shape[0]])\n", + " )\n", + " subset = jax.random.choice(\n", + " 
jax.random.PRNGKey(10),\n", + " foreground_pixel_coordinates.shape[0],\n", + " shape=(NUM_KEYPOINTS_TO_SELECT,),\n", + " replace=False,\n", + " )\n", "\n", " depth = jnp.array(scaled_down_training_image.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image = b.t3d.unproject_depth(depth, scaled_down_training_image.intrinsics)\n", + " point_cloud_image = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image.intrinsics\n", + " )\n", + "\n", + " keypoint_world_coordinates = point_cloud_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " _keypoint_coordinates = b.t3d.apply_transform(\n", + " keypoint_world_coordinates, b.t3d.inverse_pose(pose)\n", + " )\n", + " _keypoint_embeddings = embeddings[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", "\n", - " keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " _keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(pose))\n", - " _keypoint_embeddings = embeddings[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " \n", " keypoint_coordinates.append(_keypoint_coordinates)\n", " keypoint_embeddings.append(_keypoint_embeddings)\n", - " del embeddings\n", - "\n", - " " + " del embeddings" ] }, { @@ -559,36 +622,48 @@ }, "outputs": [], "source": [ - "\n", - "\n", "def render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings):\n", - " point_cloud_img = b.RENDERER.render_single_object(pose, jnp.int32(0))[:,:,:3]\n", - " point_cloud_img_in_object_frame = b.t3d.apply_transform(point_cloud_img, b.t3d.inverse_pose(pose))\n", - "\n", - " distances_to_keypoints = (\n", - " jnp.linalg.norm(point_cloud_img_in_object_frame[:, :,None,...] 
- keypoint_coordinates[None, None,:,...],\n", - " axis=-1\n", - " ))\n", + " point_cloud_img = b.RENDERER.render_single_object(pose, jnp.int32(0))[:, :, :3]\n", + " point_cloud_img_in_object_frame = b.t3d.apply_transform(\n", + " point_cloud_img, b.t3d.inverse_pose(pose)\n", + " )\n", + "\n", + " distances_to_keypoints = jnp.linalg.norm(\n", + " point_cloud_img_in_object_frame[:, :, None, ...]\n", + " - keypoint_coordinates[None, None, :, ...],\n", + " axis=-1,\n", + " )\n", " index_of_nearest_keypoint = distances_to_keypoints.argmin(2)\n", " distance_to_nearest_keypoints = distances_to_keypoints.min(2)\n", "\n", " DISTANCE_THRESHOLD = 0.04\n", - " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[...,None]\n", + " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[..., None]\n", " selected_keypoints = keypoint_coordinates[index_of_nearest_keypoint]\n", - " rendered_embeddings_image = keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " rendered_embeddings_image = (\n", + " keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " )\n", " return point_cloud_img, rendered_embeddings_image\n", "\n", + "\n", "def score_pose(pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings):\n", - " _,rendered_embedding_image = render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings)\n", - " dot_products = jnp.einsum(\"abi,abi->ab\", rendered_embedding_image, observed_embeddings)\n", + " _, rendered_embedding_image = render_embedding_image(\n", + " pose, keypoint_coordinates, keypoint_embeddings\n", + " )\n", + " dot_products = jnp.einsum(\n", + " \"abi,abi->ab\", rendered_embedding_image, observed_embeddings\n", + " )\n", " return dot_products.mean()\n", "\n", + "\n", "def get_pca(embeddings):\n", - " features_flat = torch.from_numpy(np.array(embeddings).reshape(-1, embeddings.shape[-1]))\n", + " features_flat = torch.from_numpy(\n", + " np.array(embeddings).reshape(-1, embeddings.shape[-1])\n", + " )\n", " U, S, V = torch.pca_lowrank(features_flat - features_flat.mean(0), niter=10)\n", " proj_PCA = jnp.array(V[:, :3])\n", " return proj_PCA\n", "\n", + "\n", "def get_colors(features, proj_V):\n", " features_flat = features.reshape(-1, features.shape[-1])\n", " feat_rgb = features_flat @ proj_V\n", @@ -596,6 +671,7 @@ " feat_rgb = feat_rgb.reshape(features.shape[:-1] + (3,))\n", " return feat_rgb\n", "\n", + "\n", "score_pose_jit = jax.jit(score_pose)" ] }, @@ -608,19 +684,19 @@ }, "outputs": [], "source": [ - "\n", "angles = jnp.linspace(-jnp.pi, jnp.pi, 300)\n", - "angle_to_pose = lambda angle : b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", + "angle_to_pose = lambda angle: b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", " jnp.array([0.0, 0.6, 0.0]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " )\n", + ") @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", "scorer = lambda angle, observed_embeddings: score_pose(\n", - " angle_to_pose(angle),\n", - " keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " angle_to_pose(angle), keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", ")\n", "scorer_jit = jax.jit(scorer)\n", - "scorer_parallel_jit = jax.jit(jax.vmap(scorer, in_axes=(0,None)))\n" + "scorer_parallel_jit = jax.jit(jax.vmap(scorer, in_axes=(0, None)))" ] }, { @@ -659,7 +735,9 @@ }, "outputs": [], 
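Note (not part of the diff): the next cell evaluates the angle scorer over a 300-point grid in ten `jnp.array_split` batches, which bounds peak memory for the vmapped scorer, then takes the argmax. A self-contained sketch of the same pattern, with a toy score function standing in for `scorer(angle, observed_embeddings)`:

    import jax
    import jax.numpy as jnp

    def score(angle):
        # Toy unimodal stand-in for the notebook's scorer.
        return -((angle - 0.3) ** 2)

    score_parallel = jax.jit(jax.vmap(score))
    angles = jnp.linspace(-jnp.pi, jnp.pi, 300)
    scores = jnp.concatenate(
        [score_parallel(chunk) for chunk in jnp.array_split(angles, 10)]
    )
    best_angle = angles[scores.argmax()]
    print(best_angle)  # grid point closest to 0.3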
"source": [ - "posterior = jnp.concatenate([scorer_parallel_jit(i, observed_embeddings) for i in jnp.array_split(angles, 10)])\n", + "posterior = jnp.concatenate(\n", + " [scorer_parallel_jit(i, observed_embeddings) for i in jnp.array_split(angles, 10)]\n", + ")\n", "best_angle = angles[posterior.argmax()]\n", "print(best_angle)\n", "best_pose = angle_to_pose(best_angle)" @@ -673,7 +751,7 @@ "outputs": [], "source": [ "colors = get_colors(observed_embeddings, proj_V)\n", - "embedding_image = b.scale_image(b.get_rgb_image(colors * 255.0),14.0)\n", + "embedding_image = b.scale_image(b.get_rgb_image(colors * 255.0), 14.0)\n", "embedding_image" ] }, @@ -686,11 +764,13 @@ }, "outputs": [], "source": [ - "pc_img, img = render_embedding_image(angle_to_pose(best_angle), keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, img = render_embedding_image(\n", + " angle_to_pose(best_angle), keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(observed_embeddings, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", - "rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", - "rerendered_embeddings = b.scale_image(b.get_rgb_image(rgba),14.0)\n", + "rgba = rgba.at[pc_img[:, :, 2] > intrinsics.far - 0.01, :3].set(255.0)\n", + "rerendered_embeddings = b.scale_image(b.get_rgb_image(rgba), 14.0)\n", "rerendered_embeddings" ] }, @@ -703,12 +783,7 @@ }, "outputs": [], "source": [ - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb),\n", - " embedding_image,\n", - " rerendered_embeddings\n", - "]\n", - " )" + "b.multi_panel([b.get_rgb_image(test_rgbd.rgb), embedding_image, rerendered_embeddings])" ] }, { @@ -725,7 +800,9 @@ "proj_V = get_pca(keypoint_embeddings)\n", "colors = get_colors(keypoint_embeddings, proj_V)\n", "b.clear()\n", - "obj = g.PointCloud(np.transpose(keypoint_coordinates)*30.0, np.transpose(colors), size=0.1)\n", + "obj = g.PointCloud(\n", + " np.transpose(keypoint_coordinates) * 30.0, np.transpose(colors), size=0.1\n", + ")\n", "b.meshcatviz.VISUALIZER[\"2\"].set_object(obj)" ] }, @@ -740,7 +817,7 @@ "source": [ "b.setup_renderer(test_rgbd.intrinsics)\n", "# b.RENDERER.add_mesh_from_file(mesh_filename, scaling_factor=SCALING_FACTOR)\n", - "point_cloud_img = b.RENDERER.render_single_object(best_pose, jnp.int32(0))[:,:,:3]" + "point_cloud_img = b.RENDERER.render_single_object(best_pose, jnp.int32(0))[:, :, :3]" ] }, { @@ -752,11 +829,13 @@ }, "outputs": [], "source": [ - "mask = (test_rgbd.intrinsics.near < point_cloud_img[:,:,2]) * (point_cloud_img[:,:,2] < test_rgbd.intrinsics.far)\n", - "print(point_cloud_img[:,:,2][mask].min(), point_cloud_img[:,:,2][mask].max())\n", + "mask = (test_rgbd.intrinsics.near < point_cloud_img[:, :, 2]) * (\n", + " point_cloud_img[:, :, 2] < test_rgbd.intrinsics.far\n", + ")\n", + "print(point_cloud_img[:, :, 2][mask].min(), point_cloud_img[:, :, 2][mask].max())\n", "b.get_depth_image(1.0 * mask)\n", - "img = jnp.array(b.get_depth_image(point_cloud_img[:,:,2], min=0.46, max=0.65))\n", - "img = img.at[jnp.invert(mask) , :3].set(255.0)\n", + "img = jnp.array(b.get_depth_image(point_cloud_img[:, :, 2], min=0.46, max=0.65))\n", + "img = img.at[jnp.invert(mask), :3].set(255.0)\n", "b.get_rgb_image(img)" ] }, @@ -782,20 +861,25 @@ "outputs": [], "source": [ "pose = best_pose\n", - "point_cloud_img = b.RENDERER.render_single_object(pose, jnp.int32(0))[:,:,:3]\n", - "point_cloud_img_in_object_frame = b.t3d.apply_transform(point_cloud_img, b.t3d.inverse_pose(pose))\n", + "point_cloud_img = 
b.RENDERER.render_single_object(pose, jnp.int32(0))[:, :, :3]\n", + "point_cloud_img_in_object_frame = b.t3d.apply_transform(\n", + " point_cloud_img, b.t3d.inverse_pose(pose)\n", + ")\n", "\n", - "distances_to_keypoints = (\n", - " jnp.linalg.norm(point_cloud_img_in_object_frame[:, :,None,...] - keypoint_coordinates[None, None,:,...],\n", - " axis=-1\n", - "))\n", + "distances_to_keypoints = jnp.linalg.norm(\n", + " point_cloud_img_in_object_frame[:, :, None, ...]\n", + " - keypoint_coordinates[None, None, :, ...],\n", + " axis=-1,\n", + ")\n", "index_of_nearest_keypoint = distances_to_keypoints.argmin(2)\n", "distance_to_nearest_keypoints = distances_to_keypoints.min(2)\n", "\n", "DISTANCE_THRESHOLD = 0.2\n", - "valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[...,None]\n", + "valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[..., None]\n", "selected_keypoints = keypoint_coordinates[index_of_nearest_keypoint]\n", - "rendered_embeddings_image = keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + "rendered_embeddings_image = (\n", + " keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + ")\n", "\n", "colors = get_colors(rendered_embeddings_image, proj_V)\n", "b.get_rgb_image(colors * 255.0)" diff --git a/scripts/experiments/deeplearning/dino/test_dino.ipynb b/scripts/experiments/deeplearning/dino/test_dino.ipynb index 371b8d9d..4018bd7a 100644 --- a/scripts/experiments/deeplearning/dino/test_dino.ipynb +++ b/scripts/experiments/deeplearning/dino/test_dino.ipynb @@ -23,7 +23,7 @@ "import bayes3d.utils.ycb_loader\n", "from bayes3d.viz.open3dviz import Open3DVisualizer\n", "from tqdm import tqdm\n", - "import open3d as o3d\n" + "import open3d as o3d" ] }, { @@ -48,7 +48,7 @@ "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')\n", + "dinov2_vitg14 = torch.hub.load(\"facebookresearch/dinov2\", \"dinov2_vits14\")\n", "dino = dinov2_vitg14.to(device) # Same issue with larger model" ] }, @@ -62,19 +62,26 @@ "outputs": [], "source": [ "def get_embeddings(rgbd):\n", - " img = b.get_rgb_image(rgbd.rgb).convert('RGB')\n", + " img = b.get_rgb_image(rgbd.rgb).convert(\"RGB\")\n", " patch_w, patch_h = np.array(img.size) // 14\n", - " transform = T.Compose([\n", - " T.GaussianBlur(9, sigma=(0.1, 2.0)),\n", - " T.Resize((patch_h * 14, patch_w * 14)),\n", - " T.CenterCrop((patch_h * 14, patch_w * 14)),\n", - " T.ToTensor(),\n", - " T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n", - " ])\n", + " transform = T.Compose(\n", + " [\n", + " T.GaussianBlur(9, sigma=(0.1, 2.0)),\n", + " T.Resize((patch_h * 14, patch_w * 14)),\n", + " T.CenterCrop((patch_h * 14, patch_w * 14)),\n", + " T.ToTensor(),\n", + " T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n", + " ]\n", + " )\n", " tensor = transform(img)[:3].unsqueeze(0).to(device)\n", " with torch.no_grad():\n", " features_dict = dinov2_vitg14.forward_features(tensor)\n", - " features = features_dict['x_norm_patchtokens'][0].reshape((patch_h, patch_w, 384)).permute(2, 0, 1).unsqueeze(0)\n", + " features = (\n", + " features_dict[\"x_norm_patchtokens\"][0]\n", + " .reshape((patch_h, patch_w, 384))\n", + " .permute(2, 0, 1)\n", + " .unsqueeze(0)\n", + " )\n", " img_feat_norm = torch.nn.functional.normalize(features, dim=1)\n", " output = jnp.array(img_feat_norm.cpu().detach().numpy())[0]\n", " del 
img_feat_norm\n", @@ -82,7 +89,7 @@ " del tensor\n", " del features_dict\n", " torch.cuda.empty_cache()\n", - " return jnp.transpose(output, (1,2,0))" + " return jnp.transpose(output, (1, 2, 0))" ] }, { @@ -94,15 +101,11 @@ }, "outputs": [], "source": [ - "w,h = 1400,1400\n", + "w, h = 1400, 1400\n", "intrinsics = b.Intrinsics(\n", - " height=h,\n", - " width=w,\n", - " fx=2000.0, fy=2000.0,\n", - " cx=w/2.0, cy=h/2.0,\n", - " near=0.001, far=6.0\n", + " height=h, width=w, fx=2000.0, fy=2000.0, cx=w / 2.0, cy=h / 2.0, near=0.001, far=6.0\n", ")\n", - "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0/14.0)\n", + "scaled_down_intrinsics = b.camera.scale_camera_parameters(intrinsics, 1.0 / 14.0)\n", "scaled_down_intrinsics" ] }, @@ -127,12 +130,14 @@ }, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "mesh_paths = []\n", - "for idx in range(1,22):\n", - " mesh_paths.append(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"))\n", - "SCALING_FACTOR = 1.0/1000.0\n", - "obj_idx = 1\n" + "for idx in range(1, 22):\n", + " mesh_paths.append(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", + " )\n", + "SCALING_FACTOR = 1.0 / 1000.0\n", + "obj_idx = 1" ] }, { @@ -153,7 +158,7 @@ "metadata": {}, "outputs": [], "source": [ - "viz = Open3DVisualizer(intrinsics)\n" + "viz = Open3DVisualizer(intrinsics)" ] }, { @@ -177,11 +182,19 @@ "metadata": {}, "outputs": [], "source": [ - "object_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, 0.0]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]])\n", + "object_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.6, 0.0]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in np.linspace(-jnp.pi, jnp.pi, 101)[:-1]\n", + " ]\n", + ")\n", "# for (i, pose) in enumerate(object_poses):\n", "# b.show_pose(f\"{i}\", pose)" ] @@ -197,7 +210,7 @@ "for i, pose in tqdm(enumerate(object_poses)):\n", " # if i > 0:\n", " # mesh.meshes[0].mesh.transform(b.inv\n", - " \n", + "\n", " # viz.render.scene.add_model(f\"1\", mesh)\n", " rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(pose))\n", " images.append(rgbd)\n", @@ -245,7 +258,7 @@ "source": [ "num_images = len(images)\n", "num_training_images = 10\n", - "training_indices = jnp.arange(0,num_images-1, num_images // num_training_images)\n", + "training_indices = jnp.arange(0, num_images - 1, num_images // num_training_images)\n", "# b.hstack_images([\n", "# b.get_rgb_image(images[idx].rgb) for idx in training_indices\n", "# ])" @@ -265,44 +278,69 @@ "sparse_descriptors = []\n", "for iteration in range(len(training_indices)):\n", " index = training_indices[iteration]\n", - " index_next = training_indices[(iteration+1) % len(training_indices)]\n", + " index_next = training_indices[(iteration + 1) % len(training_indices)]\n", " print(index, index_next)\n", " keys = jax.random.split(key)[1]\n", - " \n", + "\n", " training_image = images[index]\n", " object_pose = 
object_poses[index]\n", - " \n", - " scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", + "\n", + " scaled_down_training_image = training_image.scale_rgbd(1.0 / 14.0)\n", " embedding_image = get_embeddings(training_image)\n", " embedding_image_next = get_embeddings(images[index_next])\n", - " \n", - " foreground_mask = (jnp.inf != scaled_down_training_image.depth)\n", + "\n", + " foreground_mask = jnp.inf != scaled_down_training_image.depth\n", " foreground_pixel_coordinates = jnp.transpose(jnp.vstack(jnp.where(foreground_mask)))\n", - " \n", + "\n", " depth = jnp.array(scaled_down_training_image.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image = b.t3d.unproject_depth(depth, scaled_down_training_image.intrinsics)\n", - " point_cloud_image_object_frame = b.t3d.apply_transform(point_cloud_image, b.t3d.inverse_pose(object_pose))\n", - " \n", - " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0/14.0)\n", + " point_cloud_image = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image.intrinsics\n", + " )\n", + " point_cloud_image_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image, b.t3d.inverse_pose(object_pose)\n", + " )\n", + "\n", + " scaled_down_training_image_next = images[index_next].scale_rgbd(1.0 / 14.0)\n", " depth = jnp.array(scaled_down_training_image_next.depth)\n", " depth = depth.at[depth == jnp.inf].set(0.0)\n", - " point_cloud_image_next = b.t3d.unproject_depth(depth, scaled_down_training_image_next.intrinsics)\n", - " point_cloud_image_next_object_frame = b.t3d.apply_transform(point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next]))\n", - " \n", - " embeddings_subset = embedding_image[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " coordinates_subset = point_cloud_image_object_frame[foreground_pixel_coordinates[:,0], foreground_pixel_coordinates[:,1],:]\n", - " similarity_embedding = jnp.einsum(\"abi, ki->abk\", embedding_image_next, embeddings_subset)\n", + " point_cloud_image_next = b.t3d.unproject_depth(\n", + " depth, scaled_down_training_image_next.intrinsics\n", + " )\n", + " point_cloud_image_next_object_frame = b.t3d.apply_transform(\n", + " point_cloud_image_next, b.t3d.inverse_pose(object_poses[index_next])\n", + " )\n", + "\n", + " embeddings_subset = embedding_image[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " coordinates_subset = point_cloud_image_object_frame[\n", + " foreground_pixel_coordinates[:, 0], foreground_pixel_coordinates[:, 1], :\n", + " ]\n", + " similarity_embedding = jnp.einsum(\n", + " \"abi, ki->abk\", embedding_image_next, embeddings_subset\n", + " )\n", " best_match = similarity_embedding.argmax(-1)\n", - " distance_to_best_match = jnp.linalg.norm(point_cloud_image_next_object_frame - coordinates_subset[best_match,:], axis=-1)\n", - " \n", + " distance_to_best_match = jnp.linalg.norm(\n", + " point_cloud_image_next_object_frame - coordinates_subset[best_match, :], axis=-1\n", + " )\n", + "\n", " selected = (distance_to_best_match < 0.01) * (similarity_embedding.max(-1) > 0.9)\n", " subset = jnp.unique(best_match[selected])\n", "\n", - "\n", - " _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " _keypoint_coordinates = 
b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", + " _keypoint_embeddings = embedding_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " keypoint_world_coordinates = point_cloud_image[\n", + " foreground_pixel_coordinates[subset, 0],\n", + " foreground_pixel_coordinates[subset, 1],\n", + " :,\n", + " ]\n", + " _keypoint_coordinates = b.t3d.apply_transform(\n", + " keypoint_world_coordinates, b.t3d.inverse_pose(object_pose)\n", + " )\n", "\n", " keypoint_coordinates.append(_keypoint_coordinates)\n", " keypoint_embeddings.append(_keypoint_embeddings)\n", @@ -335,10 +373,10 @@ "# index_next = training_indices[(iteration+1) % len(training_indices)]\n", "# print(index, index_next)\n", "# keys = jax.random.split(key)[1]\n", - " \n", + "\n", "# training_image = images[index]\n", "# object_pose = object_poses[index]\n", - " \n", + "\n", "# scaled_down_training_image = training_image.scale_rgbd(1.0/14.0)\n", "# embedding_image = get_embeddings(training_image)\n", "\n", @@ -356,7 +394,7 @@ "# keypoint_world_coordinates = point_cloud_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", "# _keypoint_coordinates = b.t3d.apply_transform(keypoint_world_coordinates, b.t3d.inverse_pose(object_pose))\n", "# _keypoint_embeddings = embedding_image[foreground_pixel_coordinates[subset,0], foreground_pixel_coordinates[subset,1],:]\n", - " \n", + "\n", "# keypoint_coordinates.append(_keypoint_coordinates)\n", "# keypoint_embeddings.append(_keypoint_embeddings)\n", "# del embedding_image\n", @@ -373,24 +411,29 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings):\n", - " point_cloud_img = b.RENDERER.render(pose[None,...], jnp.array([0]))[:,:,:3]\n", - " point_cloud_img_in_object_frame = b.t3d.apply_transform(point_cloud_img, b.t3d.inverse_pose(pose))\n", - "\n", - " distances_to_keypoints = (\n", - " jnp.linalg.norm(point_cloud_img_in_object_frame[:, :,None,...] 
- keypoint_coordinates[None, None,:,...],\n", - " axis=-1\n", - " ))\n", + " point_cloud_img = b.RENDERER.render(pose[None, ...], jnp.array([0]))[:, :, :3]\n", + " point_cloud_img_in_object_frame = b.t3d.apply_transform(\n", + " point_cloud_img, b.t3d.inverse_pose(pose)\n", + " )\n", + "\n", + " distances_to_keypoints = jnp.linalg.norm(\n", + " point_cloud_img_in_object_frame[:, :, None, ...]\n", + " - keypoint_coordinates[None, None, :, ...],\n", + " axis=-1,\n", + " )\n", " index_of_nearest_keypoint = distances_to_keypoints.argmin(2)\n", " distance_to_nearest_keypoints = distances_to_keypoints.min(2)\n", "\n", " DISTANCE_THRESHOLD = 0.04\n", - " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[...,None]\n", + " valid_match_mask = (distance_to_nearest_keypoints < DISTANCE_THRESHOLD)[..., None]\n", " selected_keypoints = keypoint_coordinates[index_of_nearest_keypoint]\n", - " rendered_embeddings_image = keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " rendered_embeddings_image = (\n", + " keypoint_embeddings[index_of_nearest_keypoint] * valid_match_mask\n", + " )\n", " return point_cloud_img, rendered_embeddings_image\n", "\n", + "\n", "vmf_score = lambda q, q_mean, conc: tfp.distributions.VonMisesFisher(\n", " q_mean, conc\n", ").log_prob(q)\n", @@ -402,28 +445,30 @@ "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m),(m)->()',\n", + " signature=\"(m),(m)->()\",\n", " excluded=(2,),\n", ")\n", - "def vmf_vectorize(\n", - " embeddings,\n", - " embeddings_mean,\n", - " conc\n", - "):\n", + "def vmf_vectorize(embeddings, embeddings_mean, conc):\n", " return vmf_score(embeddings, embeddings_mean, conc)\n", "\n", "\n", "def score_pose(pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings):\n", - " _,rendered_embedding_image = render_embedding_image(pose, keypoint_coordinates, keypoint_embeddings)\n", + " _, rendered_embedding_image = render_embedding_image(\n", + " pose, keypoint_coordinates, keypoint_embeddings\n", + " )\n", " scores = vmf_vectorize(observed_embeddings, rendered_embedding_image, 1000.0)\n", " return scores\n", "\n", + "\n", "def get_pca(embeddings):\n", - " features_flat = torch.from_numpy(np.array(embeddings).reshape(-1, embeddings.shape[-1]))\n", + " features_flat = torch.from_numpy(\n", + " np.array(embeddings).reshape(-1, embeddings.shape[-1])\n", + " )\n", " U, S, V = torch.pca_lowrank(features_flat - features_flat.mean(0), niter=10)\n", " proj_PCA = jnp.array(V[:, :3])\n", " return proj_PCA\n", "\n", + "\n", "def get_colors(features, proj_V):\n", " features_flat = features.reshape(-1, features.shape[-1])\n", " feat_rgb = features_flat @ proj_V\n", @@ -431,8 +476,9 @@ " feat_rgb = feat_rgb.reshape(features.shape[:-1] + (3,))\n", " return feat_rgb\n", "\n", + "\n", "score_pose_jit = jax.jit(score_pose)\n", - "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None )))" + "score_pose_parallel_jit = jax.jit(jax.vmap(score_pose, in_axes=(0, None, None, None)))" ] }, { @@ -449,7 +495,9 @@ "proj_V = get_pca(keypoint_embeddings)\n", "colors = get_colors(keypoint_embeddings, proj_V)\n", "b.clear()\n", - "obj = g.PointCloud(np.transpose(keypoint_coordinates)*10.0, np.transpose(colors), size=0.1)\n", + "obj = g.PointCloud(\n", + " np.transpose(keypoint_coordinates) * 10.0, np.transpose(colors), size=0.1\n", + ")\n", "b.meshcatviz.VISUALIZER[\"2\"].set_object(obj)" ] }, @@ -475,8 +523,10 @@ "metadata": {}, "outputs": [], "source": [ - 
"b.RENDERER.render(jnp.eye(4)[None,...], jnp.array([0]));\n", - "pc_img, rendered_embedding_image = render_embedding_image(object_poses[0], keypoint_coordinates, keypoint_embeddings);" + "b.RENDERER.render(jnp.eye(4)[None, ...], jnp.array([0]))\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " object_poses[0], keypoint_coordinates, keypoint_embeddings\n", + ")" ] }, { @@ -490,7 +540,7 @@ "source": [ "IDX = 15\n", "test_rgbd = images[IDX]\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "# b.get_rgb_image(test_rgbd.rgb)" ] @@ -504,7 +554,14 @@ }, "outputs": [], "source": [ - "posterior = jnp.concatenate([score_pose_parallel_jit(i, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1) for i in jnp.array_split(object_poses, 10)])\n", + "posterior = jnp.concatenate(\n", + " [\n", + " score_pose_parallel_jit(\n", + " i, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " for i in jnp.array_split(object_poses, 10)\n", + " ]\n", + ")\n", "print(posterior.argmax())\n", "best_pose = object_poses[posterior.argmax()]\n", "print(best_pose)" @@ -540,21 +597,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(best_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " best_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -565,9 +623,11 @@ "metadata": {}, "outputs": [], "source": [ - "random_pose = b.transform_from_pos(jnp.array([0.0, 0.0, 0.6])) @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", + "random_pose = b.transform_from_pos(\n", + " jnp.array([0.0, 0.0, 0.6])\n", + ") @ b.distributions.vmf_jit(jax.random.PRNGKey(40), 0.001)\n", "test_rgbd = viz.capture_image(intrinsics, b.t3d.inverse_pose(random_pose))\n", - "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0/14.0)\n", + "test_rgbd_scaled = test_rgbd.scale_rgbd(1.0 / 14.0)\n", "observed_embeddings = get_embeddings(test_rgbd)\n", "b.get_rgb_image(test_rgbd.scale_rgbd(0.2).rgb)" ] @@ -599,12 +659,14 @@ "metadata": {}, "outputs": [], "source": [ - "match_scores = jnp.einsum(\"abk,ck\",observed_embeddings, keypoint_embeddings)\n", + "match_scores = jnp.einsum(\"abk,ck\", observed_embeddings, keypoint_embeddings)\n", "top_match = match_scores.max(-1)\n", 
"top_match_idx = match_scores.argmax(-1)\n", "\n", "THRESHOLD = 0.8\n", - "match_mask = (top_match > THRESHOLD) * (test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far)\n", + "match_mask = (top_match > THRESHOLD) * (\n", + " test_rgbd_scaled.depth < test_rgbd_scaled.intrinsics.far\n", + ")\n", "print(match_mask.sum())\n", "b.get_depth_image(1.0 * match_mask)" ] @@ -616,14 +678,16 @@ "metadata": {}, "outputs": [], "source": [ - "observed_point_cloud_image = b.unproject_depth_jit(test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics)\n", + "observed_point_cloud_image = b.unproject_depth_jit(\n", + " test_rgbd_scaled.depth, test_rgbd_scaled.intrinsics\n", + ")\n", "\n", - "observed_match_coordinates = observed_point_cloud_image[match_mask,:]\n", - "model_coordinates = keypoint_coordinates[top_match_idx[match_mask],:]\n", + "observed_match_coordinates = observed_point_cloud_image[match_mask, :]\n", + "model_coordinates = keypoint_coordinates[top_match_idx[match_mask], :]\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1,3))\n", - "b.show_cloud(\"2\", model_coordinates.reshape(-1,3), color=b.RED)" + "b.show_cloud(\"1\", observed_match_coordinates.reshape(-1, 3))\n", + "b.show_cloud(\"2\", model_coordinates.reshape(-1, 3), color=b.RED)" ] }, { @@ -634,11 +698,13 @@ "outputs": [], "source": [ "b.clear()\n", - "estimated_pose = b.estimate_transform_between_clouds(model_coordinates, observed_match_coordinates)\n", - "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10],random_pose, 0.1, 10.0)\n", + "estimated_pose = b.estimate_transform_between_clouds(\n", + " model_coordinates, observed_match_coordinates\n", + ")\n", + "estimated_pose = b.distributions.gaussian_vmf_jit(keys[10], random_pose, 0.1, 10.0)\n", "b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", "b.set_pose(\"mesh\", estimated_pose)\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -648,8 +714,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -670,16 +744,20 @@ "outputs": [], "source": [ "for _ in range(20):\n", - " potential_poses = gaussian_vmf_parallel(keys,estimated_pose, 0.01, 20000.0)\n", - " current_score = score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", - " scores = score_pose_parallel_jit(potential_poses, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[:,test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " potential_poses = gaussian_vmf_parallel(keys, estimated_pose, 0.01, 20000.0)\n", + " current_score = score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + " scores = score_pose_parallel_jit(\n", + " potential_poses, 
keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[:, test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", " if scores.max() > current_score:\n", " estimated_pose = potential_poses[scores.argmax()]\n", " keys = split_jit(keys[0], 100)\n", " print(scores.max(), current_score)\n", " b.show_trimesh(\"mesh\", b.RENDERER.meshes[0])\n", " b.set_pose(\"mesh\", estimated_pose)\n", - " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + " b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { @@ -689,8 +767,16 @@ "metadata": {}, "outputs": [], "source": [ - "print(score_pose(random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))\n", - "print(score_pose(estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings)[test_rgbd_scaled.depth != jnp.inf].mean(-1))" + "print(\n", + " score_pose(\n", + " random_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")\n", + "print(\n", + " score_pose(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings, observed_embeddings\n", + " )[test_rgbd_scaled.depth != jnp.inf].mean(-1)\n", + ")" ] }, { @@ -711,21 +797,22 @@ "observed_embedding_colors = get_colors(observed_embeddings, proj_V)\n", "observed_embeddings_image_viz = b.get_rgb_image(observed_embedding_colors * 255.0)\n", "\n", - "pc_img, rendered_embedding_image = render_embedding_image(estimated_pose, keypoint_coordinates, keypoint_embeddings)\n", + "pc_img, rendered_embedding_image = render_embedding_image(\n", + " estimated_pose, keypoint_coordinates, keypoint_embeddings\n", + ")\n", "colors = get_colors(rendered_embedding_image, proj_V)\n", "rgba = jnp.array(b.get_rgb_image(colors * 255.0))\n", "# rgba = rgba.at[pc_img[:,:,2] > intrinsics.far - 0.01, :3].set(255.0)\n", "rerendered_embeddings_viz = b.get_rgb_image(rgba)\n", "\n", - "b.multi_panel([\n", - " b.get_rgb_image(test_rgbd.rgb), \n", - " b.scale_image(observed_embeddings_image_viz, 14.0),\n", - " b.scale_image(rerendered_embeddings_viz, 14.0)\n", - "],labels=[\n", - " \"Observed RGB\",\n", - " \"Embeddings\",\n", - " \"Reconstruction\"\n", - "],label_fontsize=50\n", + "b.multi_panel(\n", + " [\n", + " b.get_rgb_image(test_rgbd.rgb),\n", + " b.scale_image(observed_embeddings_image_viz, 14.0),\n", + " b.scale_image(rerendered_embeddings_viz, 14.0),\n", + " ],\n", + " labels=[\"Observed RGB\", \"Embeddings\", \"Reconstruction\"],\n", + " label_fontsize=50,\n", ").convert(\"RGB\")" ] }, @@ -738,7 +825,7 @@ "source": [ "# b.clear()\n", "# b.show_trimesh(\"mesh\", b.RENDERER.meshes[obj_idx])\n", - "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1,3))\n" + "b.show_cloud(\"obs\", observed_point_cloud_image.reshape(-1, 3))" ] }, { diff --git a/scripts/experiments/deeplearning/duduo.ipynb b/scripts/experiments/deeplearning/duduo.ipynb index 9fbe5ff8..d69375fc 100644 --- a/scripts/experiments/deeplearning/duduo.ipynb +++ b/scripts/experiments/deeplearning/duduo.ipynb @@ -23,7 +23,7 @@ "from PIL import Image\n", "from transformers import AutoModel\n", "import os\n", - "import bayes3d as b\n" + "import bayes3d as b" ] }, { @@ -33,11 +33,11 @@ "outputs": [], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('49', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"49\", \"1\", 
bop_ycb_dir)\n", "frame_src = b.get_rgb_image(rgbd.rgb)\n", "\n", "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('51', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"51\", \"1\", bop_ycb_dir)\n", "frame_dst = b.get_rgb_image(rgbd.rgb)" ] }, @@ -86,7 +86,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = AutoModel.from_pretrained(\"stevetod/doduo\", trust_remote_code=True)\n" + "model = AutoModel.from_pretrained(\"stevetod/doduo\", trust_remote_code=True)" ] }, { @@ -157,7 +157,9 @@ " y = int(event.ydata)\n", " frame1_img_mark = cv2.circle(frame1_img.copy(), (x, y), 3, (0, 0, 255), -1)\n", " max_x, max_y = matching[:, y, x]\n", - " frame2_img_mark = cv2.circle(frame2_img.copy(), (max_x, max_y), 3, (0, 255, 0), -1)\n", + " frame2_img_mark = cv2.circle(\n", + " frame2_img.copy(), (max_x, max_y), 3, (0, 255, 0), -1\n", + " )\n", " axs[0].imshow(frame1_img_mark)\n", " axs[0].axis(\"off\")\n", " axs[1].imshow(frame2_img_mark)\n", diff --git a/scripts/experiments/deeplearning/feature_detection/feature_detector.ipynb b/scripts/experiments/deeplearning/feature_detection/feature_detector.ipynb index 8d0a88e9..f91690fe 100644 --- a/scripts/experiments/deeplearning/feature_detection/feature_detector.ipynb +++ b/scripts/experiments/deeplearning/feature_detection/feature_detector.ipynb @@ -15,6 +15,7 @@ "from functools import partial\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", + "\n", "# import bayes3d.genjax\n", "# import genjax\n", "import pathlib\n", @@ -62,22 +63,23 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.0001, far=20.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=20.0\n", ")\n", "\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -89,7 +91,7 @@ "IDX = 15\n", "table_pose = b.t3d.inverse_pose(\n", " b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.8, .15]),\n", + " jnp.array([0.0, 0.8, 0.15]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 1.0]),\n", " )\n", @@ -137,19 +139,22 @@ ], "source": [ "%matplotlib widget\n", - "img = b.RENDERER.render(table_pose[None,...], jnp.array([IDX]))\n", + "img = b.RENDERER.render(table_pose[None, ...], jnp.array([IDX]))\n", "fig, axes = plt.subplots()\n", - "b.add_depth_image(axes,img[...,2])\n", + 
"b.add_depth_image(axes, img[..., 2])\n", "location = (37, 49)\n", + "\n", + "\n", "def onclick(event):\n", " global location\n", " x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n", - " location = (x,y)\n", + " location = (x, y)\n", " print(location)\n", " axes.axis(\"on\")\n", - " axes.scatter(x,y, c='r', marker=\"x\", s=200)\n", + " axes.scatter(x, y, c=\"r\", marker=\"x\", s=200)\n", "\n", - "fig.canvas.mpl_connect('button_press_event', onclick)\n", + "\n", + "fig.canvas.mpl_connect(\"button_press_event\", onclick)\n", "plt.show()" ] }, @@ -159,33 +164,49 @@ "metadata": {}, "outputs": [], "source": [ - "coordinate = img[location[1],location[0],:3]\n", - "coordinate_in_object_frame = b.apply_transform(coordinate.reshape(-1,3), b.t3d.inverse_pose(table_pose))[0]\n", + "coordinate = img[location[1], location[0], :3]\n", + "coordinate_in_object_frame = b.apply_transform(\n", + " coordinate.reshape(-1, 3), b.t3d.inverse_pose(table_pose)\n", + ")[0]\n", "\n", "\n", "random_poses = jax.vmap(\n", - " b.distributions.gaussian_vmf_jit, in_axes=(0,None, None, None))(\n", - " jax.random.split(jax.random.PRNGKey(10),4000), table_pose, 0.05, 1.0)\n", - "coordinates = jnp.einsum(\"j,kij->ki\", b.add_homogenous_ones(coordinate_in_object_frame), random_poses)\n", - "pixel_coordinates = jnp.round(b.camera.project_cloud_to_pixels(coordinates, intrinsics)).astype(jnp.int32)\n", - "\n", - "images = b.RENDERER.render_many(random_poses[:,None,...], jnp.array([IDX]))\n", - "rendered_points = images[jnp.arange(images.shape[0]), pixel_coordinates[:,1], pixel_coordinates[:,0], :]\n", - "distances = jnp.linalg.norm(rendered_points[...,:3]- coordinates[:,:3],axis=-1)\n", + " b.distributions.gaussian_vmf_jit, in_axes=(0, None, None, None)\n", + ")(jax.random.split(jax.random.PRNGKey(10), 4000), table_pose, 0.05, 1.0)\n", + "coordinates = jnp.einsum(\n", + " \"j,kij->ki\", b.add_homogenous_ones(coordinate_in_object_frame), random_poses\n", + ")\n", + "pixel_coordinates = jnp.round(\n", + " b.camera.project_cloud_to_pixels(coordinates, intrinsics)\n", + ").astype(jnp.int32)\n", + "\n", + "images = b.RENDERER.render_many(random_poses[:, None, ...], jnp.array([IDX]))\n", + "rendered_points = images[\n", + " jnp.arange(images.shape[0]), pixel_coordinates[:, 1], pixel_coordinates[:, 0], :\n", + "]\n", + "distances = jnp.linalg.norm(rendered_points[..., :3] - coordinates[:, :3], axis=-1)\n", "valid_indices = distances < 0.005\n", "images_subset = images[valid_indices]\n", "pixel_coordinates_subset = pixel_coordinates[valid_indices]\n", "\n", "\n", "filter_size = 6\n", - "get_patch = lambda image, coordinate: jax.lax.dynamic_slice(image,\n", - " (coordinate[1]-filter_size, coordinate[0]-filter_size, 0),\n", - " (2*filter_size+1,2*filter_size+1,4))\n", + "get_patch = lambda image, coordinate: jax.lax.dynamic_slice(\n", + " image,\n", + " (coordinate[1] - filter_size, coordinate[0] - filter_size, 0),\n", + " (2 * filter_size + 1, 2 * filter_size + 1, 4),\n", + ")\n", "\n", "patches = jax.vmap(get_patch)(images_subset, pixel_coordinates_subset)\n", - "patches_centered = jnp.concatenate([patches[...,:3] - patches[:,filter_size,filter_size,:][...,None,None,:3],patches[...,3][...,None]],axis=-1)\n", + "patches_centered = jnp.concatenate(\n", + " [\n", + " patches[..., :3] - patches[:, filter_size, filter_size, :][..., None, None, :3],\n", + " patches[..., 3][..., None],\n", + " ],\n", + " axis=-1,\n", + ")\n", "\n", - "valid_indices = jnp.abs(patches_centered[:,:,2]).sum(-1).sum(-1) > 1e-4\n", + "valid_indices = 
jnp.abs(patches_centered[:, :, 2]).sum(-1).sum(-1) > 1e-4\n", "images_subset = images_subset[valid_indices]\n", "pixel_coordinates_subset = pixel_coordinates_subset[valid_indices]\n", "patches_centered = patches_centered[valid_indices]" @@ -229,8 +250,8 @@ } ], "source": [ - "images = [b.get_depth_image(img[...,2]) for img in patches_centered[:25]]\n", - "b.viz.scale_image(b.hvstack_images(images, 5, 5),4.0)" + "images = [b.get_depth_image(img[..., 2]) for img in patches_centered[:25]]\n", + "b.viz.scale_image(b.hvstack_images(images, 5, 5), 4.0)" ] }, { @@ -239,10 +260,16 @@ "metadata": {}, "outputs": [], "source": [ - "test_poses = jax.vmap(b.distributions.gaussian_vmf_jit, in_axes=(0,None, None, None))(jax.random.split(jax.random.PRNGKey(1000),200), table_pose, 0.05, 1.0)\n", - "test_images = b.RENDERER.render_many(test_poses[:,None,...], jnp.array([IDX]))\n", - "coordinates = jnp.einsum(\"j,kij->ki\", b.add_homogenous_ones(coordinate_in_object_frame), test_poses)\n", - "pixel_coordinates = jnp.round(b.camera.project_cloud_to_pixels(coordinates, intrinsics)).astype(jnp.int32)" + "test_poses = jax.vmap(b.distributions.gaussian_vmf_jit, in_axes=(0, None, None, None))(\n", + " jax.random.split(jax.random.PRNGKey(1000), 200), table_pose, 0.05, 1.0\n", + ")\n", + "test_images = b.RENDERER.render_many(test_poses[:, None, ...], jnp.array([IDX]))\n", + "coordinates = jnp.einsum(\n", + " \"j,kij->ki\", b.add_homogenous_ones(coordinate_in_object_frame), test_poses\n", + ")\n", + "pixel_coordinates = jnp.round(\n", + " b.camera.project_cloud_to_pixels(coordinates, intrinsics)\n", + ").astype(jnp.int32)" ] }, { @@ -252,51 +279,90 @@ "outputs": [], "source": [ "def get_error_between_patches(slice_centered, patch):\n", - "\n", " # far_mask_slice_centered = (slice_centered[...,2] > 10.0)[...,None]\n", " # far_mask_patch = (patch[...,2] > 10.0)[...,None]\n", - " \n", + "\n", " eps = 0.5\n", - " far_mask_slice_centered = ((1-slice_centered[...,3]) > eps)[...,None]\n", - " far_mask_patch = ((1-patch[...,3]) > eps)[...,None]\n", + " far_mask_slice_centered = ((1 - slice_centered[..., 3]) > eps)[..., None]\n", + " far_mask_patch = ((1 - patch[..., 3]) > eps)[..., None]\n", + "\n", + " slice_centered = slice_centered[..., :3]\n", + " patch = patch[..., :3]\n", "\n", - " slice_centered = slice_centered[...,:3]\n", - " patch = patch[...,:3]\n", - " \n", - " slice_centered = slice_centered * (1.0 - far_mask_slice_centered) + 1000.0 * far_mask_slice_centered\n", + " slice_centered = (\n", + " slice_centered * (1.0 - far_mask_slice_centered)\n", + " + 1000.0 * far_mask_slice_centered\n", + " )\n", " patch = patch * (1.0 - far_mask_patch) + 1000.0 * far_mask_patch\n", - " \n", + "\n", " distances = jnp.linalg.norm(slice_centered - patch, axis=-1)\n", " width = 0.005\n", - " probabilities_per_pixel = (distances > width/2)\n", + " probabilities_per_pixel = distances > width / 2\n", " return probabilities_per_pixel.sum()\n", "\n", "\n", - "get_error_between_patches_parallel_patches = jax.vmap(get_error_between_patches, in_axes=(None, 0))\n", + "get_error_between_patches_parallel_patches = jax.vmap(\n", + " get_error_between_patches, in_axes=(None, 0)\n", + ")\n", "\n", "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m)->()',\n", - " excluded=(1,2,),\n", + " signature=\"(m)->()\",\n", + " excluded=(\n", + " 1,\n", + " 2,\n", + " ),\n", ")\n", "def compute_error(ij, observed_xyz_padded, patch):\n", - " slice = jax.lax.dynamic_slice(observed_xyz_padded, (ij[0], ij[1], 0), (2*filter_size + 1, 
2*filter_size + 1, 4))\n", + " slice = jax.lax.dynamic_slice(\n", + " observed_xyz_padded,\n", + " (ij[0], ij[1], 0),\n", + " (2 * filter_size + 1, 2 * filter_size + 1, 4),\n", + " )\n", " # don't center slice, move patch to center pixel of slice instead?\n", - " slice_centered = jnp.concatenate([slice[...,:3] - slice[filter_size, filter_size,:3], slice[...,3][...,None]], axis=-1)\n", - " \n", + " slice_centered = jnp.concatenate(\n", + " [\n", + " slice[..., :3] - slice[filter_size, filter_size, :3],\n", + " slice[..., 3][..., None],\n", + " ],\n", + " axis=-1,\n", + " )\n", + "\n", " return get_error_between_patches(slice_centered, patch)\n", - " #return get_error_between_patches(slice, patch)\n", + " # return get_error_between_patches(slice, patch)\n", "\n", "\n", "def get_errors(observed_xyz, template):\n", - " observed_xyz_padded = jax.lax.pad(observed_xyz, -100.0, ((filter_size,filter_size,0,),(filter_size,filter_size,0,),(0,0,0,)))\n", - " jj, ii = jnp.meshgrid(jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0]))\n", - " indices = jnp.stack([ii,jj],axis=-1)\n", + " observed_xyz_padded = jax.lax.pad(\n", + " observed_xyz,\n", + " -100.0,\n", + " (\n", + " (\n", + " filter_size,\n", + " filter_size,\n", + " 0,\n", + " ),\n", + " (\n", + " filter_size,\n", + " filter_size,\n", + " 0,\n", + " ),\n", + " (\n", + " 0,\n", + " 0,\n", + " 0,\n", + " ),\n", + " ),\n", + " )\n", + " jj, ii = jnp.meshgrid(\n", + " jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0])\n", + " )\n", + " indices = jnp.stack([ii, jj], axis=-1)\n", " heatmap = compute_error(indices, observed_xyz_padded, template)\n", " return heatmap\n", "\n", - " \n", + "\n", "get_errors_jit = jax.jit(get_errors)\n", "get_errors_vmap_jit = jax.jit(jax.vmap(get_errors, in_axes=(None, 0)))" ] @@ -309,47 +375,63 @@ "source": [ "import matplotlib.colors as mcolors\n", "\n", + "\n", "def fig2img(fig):\n", " \"\"\"Convert a Matplotlib figure to a PIL Image and return it\"\"\"\n", " import io\n", + "\n", " buf = io.BytesIO()\n", " fig.savefig(buf)\n", " buf.seek(0)\n", " img = PIL.Image.open(buf)\n", " return img\n", "\n", + "\n", "# generate three side by side activation visualizations\n", "# shape + activation, shape, and activation\n", "\n", + "\n", "def generate_heatmap_viz(observed_xyz, heatmaps):\n", " assert observed_xyz.shape[-1] == 4\n", " fig = plt.figure()\n", " ax = fig.add_subplot(1, 3, 3)\n", - " \n", - " original_image = observed_xyz[...,2]\n", - " ax.imshow(b.preprocess_for_viz(original_image[filter_size:-filter_size,filter_size:-filter_size])) # cut off edges for filter\n", - " \n", + "\n", + " original_image = observed_xyz[..., 2]\n", + " ax.imshow(\n", + " b.preprocess_for_viz(\n", + " original_image[filter_size:-filter_size, filter_size:-filter_size]\n", + " )\n", + " ) # cut off edges for filter\n", + "\n", " best_idx = jnp.unravel_index(heatmaps.argmin(), heatmaps.shape)\n", - " \n", - " c_white = mcolors.colorConverter.to_rgba('white',alpha = 0)\n", - " c_red= mcolors.colorConverter.to_rgba('red',alpha = 1)\n", - " cmap_rb = mcolors.LinearSegmentedColormap.from_list('rb_cmap',[c_red,c_white],512)\n", - " \n", + "\n", + " c_white = mcolors.colorConverter.to_rgba(\"white\", alpha=0)\n", + " c_red = mcolors.colorConverter.to_rgba(\"red\", alpha=1)\n", + " cmap_rb = mcolors.LinearSegmentedColormap.from_list(\n", + " \"rb_cmap\", [c_red, c_white], 512\n", + " )\n", + "\n", " hmap = heatmaps[best_idx[0]]\n", " ax.imshow(hmap, cmap_rb, alpha=0.75)\n", - " ax.axis('off')\n", + " 
ax.axis(\"off\")\n", + "\n", + " ax.scatter(\n", + " best_idx[2], best_idx[1], color=\"black\", marker=\"x\", alpha=1\n", + " ) # plot the pose point\n", + " # ax.axis('off')\n", "\n", - " ax.scatter(best_idx[2], best_idx[1],color=\"black\", marker='x', alpha=1) # plot the pose point \n", - " #ax.axis('off')\n", - " \n", " ax2 = fig.add_subplot(1, 3, 1)\n", - " ax2.axis('off')\n", - " ax2.imshow(b.preprocess_for_viz(original_image[filter_size:-filter_size,filter_size:-filter_size])) # cut off edges for filter\n", - " \n", + " ax2.axis(\"off\")\n", + " ax2.imshow(\n", + " b.preprocess_for_viz(\n", + " original_image[filter_size:-filter_size, filter_size:-filter_size]\n", + " )\n", + " ) # cut off edges for filter\n", + "\n", " ax3 = fig.add_subplot(1, 3, 2)\n", - " ax3.axis('off')\n", + " ax3.axis(\"off\")\n", " ax3.imshow(hmap, cmap_rb)\n", - " \n", + "\n", " img_PIL = fig2img(fig)\n", " return img_PIL" ] @@ -406,9 +488,15 @@ "source": [ "i = 25\n", "observed_xyz = test_images[i]\n", - "heatmaps = get_errors_vmap_jit(observed_xyz, patches_centered)[:, filter_size:-filter_size,filter_size:-filter_size]\n", - "obs_mask = (observed_xyz[...,3][filter_size:-filter_size,filter_size:-filter_size]).astype(bool)\n", - "clean_heatmap = (heatmaps*obs_mask) + (1-obs_mask)[None,...] * jnp.max(heatmaps, axis=(1,2))[...,None,None]\n", + "heatmaps = get_errors_vmap_jit(observed_xyz, patches_centered)[\n", + " :, filter_size:-filter_size, filter_size:-filter_size\n", + "]\n", + "obs_mask = (\n", + " observed_xyz[..., 3][filter_size:-filter_size, filter_size:-filter_size]\n", + ").astype(bool)\n", + "clean_heatmap = (heatmaps * obs_mask) + (1 - obs_mask)[None, ...] * jnp.max(\n", + " heatmaps, axis=(1, 2)\n", + ")[..., None, None]\n", "print(clean_heatmap.min())\n", "# visualize a single heatmap\n", "generate_heatmap_viz(observed_xyz, clean_heatmap)" @@ -468,9 +556,15 @@ "\n", "heatmap_list = []\n", "for observed_xyz in test_images:\n", - " heatmaps = get_errors_vmap_jit(observed_xyz, patches_centered)[:, filter_size:-filter_size,filter_size:-filter_size]\n", - " obs_mask = (observed_xyz[...,3][filter_size:-filter_size,filter_size:-filter_size]).astype(bool)\n", - " clean_heatmap = (heatmaps*obs_mask) + (1-obs_mask)[None,...] * jnp.max(heatmaps, axis=(1,2))[...,None,None]\n", + " heatmaps = get_errors_vmap_jit(observed_xyz, patches_centered)[\n", + " :, filter_size:-filter_size, filter_size:-filter_size\n", + " ]\n", + " obs_mask = (\n", + " observed_xyz[..., 3][filter_size:-filter_size, filter_size:-filter_size]\n", + " ).astype(bool)\n", + " clean_heatmap = (heatmaps * obs_mask) + (1 - obs_mask)[None, ...] 
* jnp.max(\n", + " heatmaps, axis=(1, 2)\n", + " )[..., None, None]\n", " heatmap_list.append(generate_heatmap_viz(observed_xyz, clean_heatmap))" ] }, @@ -482,9 +576,10 @@ "source": [ "# make heatmap gif from stack\n", "\n", + "\n", "def make_gif_from_pil_images(images, filename):\n", " \"\"\"Save a list of PIL images as a GIF.\n", - " \n", + "\n", " Args:\n", " images (list): List of PIL images.\n", " filename (str): Filename to save GIF to.\n", @@ -498,7 +593,8 @@ " loop=0,\n", " )\n", "\n", - "make_gif_from_pil_images(heatmap_list, 'heatmap_video_masked.gif')" + "\n", + "make_gif_from_pil_images(heatmap_list, \"heatmap_video_masked.gif\")" ] }, { diff --git a/scripts/experiments/deeplearning/kubric_dataset_gen/breaking_cosypose.ipynb b/scripts/experiments/deeplearning/kubric_dataset_gen/breaking_cosypose.ipynb index 602cd60d..4016a1be 100644 --- a/scripts/experiments/deeplearning/kubric_dataset_gen/breaking_cosypose.ipynb +++ b/scripts/experiments/deeplearning/kubric_dataset_gen/breaking_cosypose.ipynb @@ -30,16 +30,19 @@ "model_dir = os.path.join(j.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "print(f\"{model_dir} exists: {os.path.exists(model_dir)}\")\n", "model_names = j.ycb_loader.MODEL_NAMES\n", - "model_paths = [os.path.join(model_dir,name,\"textured.obj\") for name in model_names]\n", + "model_paths = [os.path.join(model_dir, name, \"textured.obj\") for name in model_names]\n", "\n", "bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img(\"52\", \"1\", bop_ycb_dir)\n", "intrinsics = j.Intrinsics(\n", " height=rgbd.intrinsics.height,\n", " width=rgbd.intrinsics.width,\n", - " fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx,\n", - " cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0,\n", - " near=0.001, far=3.0\n", + " fx=rgbd.intrinsics.fx,\n", + " fy=rgbd.intrinsics.fx,\n", + " cx=rgbd.intrinsics.width / 2.0,\n", + " cy=rgbd.intrinsics.height / 2.0,\n", + " near=0.001,\n", + " far=3.0,\n", ")" ] }, @@ -55,13 +58,10 @@ "model_names = [\"obj_\" + f\"{str(idx+1).rjust(6, '0')}.ply\" for idx in range(21)]\n", "mesh_paths = []\n", "for name in model_names:\n", - " mesh_path = os.path.join(model_dir,name)\n", + " mesh_path = os.path.join(model_dir, name)\n", " mesh_paths.append(mesh_path)\n", - " model_scaling_factor = 1.0/1000.0\n", - " renderer.add_mesh_from_file(\n", - " mesh_path,\n", - " scaling_factor=model_scaling_factor\n", - " )" + " model_scaling_factor = 1.0 / 1000.0\n", + " renderer.add_mesh_from_file(mesh_path, scaling_factor=model_scaling_factor)" ] }, { @@ -90,7 +90,13 @@ "metadata": {}, "outputs": [], "source": [ - "all_data = j.kubric_interface.render_multiobject_parallel([model_paths[IDX],model_paths[IDX2]], object_poses[:,None,...], intrinsics, scaling_factor=1.0, lighting=3.0) # multi img singleobj" + "all_data = j.kubric_interface.render_multiobject_parallel(\n", + " [model_paths[IDX], model_paths[IDX2]],\n", + " object_poses[:, None, ...],\n", + " intrinsics,\n", + " scaling_factor=1.0,\n", + " lighting=3.0,\n", + ") # multi img singleobj" ] }, { @@ -111,12 +117,18 @@ "metadata": {}, "outputs": [], "source": [ - "# function taking the rbgd rgb and intrinstics as well as the renderer and returning the cosypose prediction \n", + "# function taking the rbgd rgb and intrinstics as well as the renderer and returning the cosypose prediction\n", "def cosypose_pred(rgb, intrinsics, 
renderer):\n", " pred = j.cosypose_utils.cosypose_interface(rgb, j.K_from_intrinsics(intrinsics))\n", - " pred_poses, pred_ids, pred_scores = pred['pred_poses'], pred['pred_ids'], pred['pred_scores']\n", - " rendered = renderer.render_multiobject(jnp.array(pred_poses[0]), jnp.array(pred_ids[0]))\n", - " return j.get_depth_image(rendered[:,:,2]) " + " pred_poses, pred_ids, pred_scores = (\n", + " pred[\"pred_poses\"],\n", + " pred[\"pred_ids\"],\n", + " pred[\"pred_scores\"],\n", + " )\n", + " rendered = renderer.render_multiobject(\n", + " jnp.array(pred_poses[0]), jnp.array(pred_ids[0])\n", + " )\n", + " return j.get_depth_image(rendered[:, :, 2])" ] }, { @@ -189,6 +201,7 @@ "\n", " return noisy_img\n", "\n", + "\n", "def make_low_resolution(img, scale_factor=0.5):\n", " \"\"\"\n", " Create a low-resolution version of an image by downsampling and upsampling.\n", @@ -201,10 +214,14 @@ " numpy.ndarray: The low-resolution image as a NumPy array.\n", " \"\"\"\n", " # Downsample the image\n", - " downsampled_img = cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_AREA)\n", + " downsampled_img = cv2.resize(\n", + " img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_AREA\n", + " )\n", "\n", " # Upsample the image\n", - " low_res_img = cv2.resize(downsampled_img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)\n", + " low_res_img = cv2.resize(\n", + " downsampled_img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST\n", + " )\n", "\n", " return low_res_img" ] @@ -228,13 +245,15 @@ "poses_list = []\n", "for i in range(3):\n", " for k in range(3):\n", - " poses_list.append(object_pose @ j.t3d.transform_from_pos(jnp.array([0.1*i, 0.1*k, 0.0])))\n", + " poses_list.append(\n", + " object_pose @ j.t3d.transform_from_pos(jnp.array([0.1 * i, 0.1 * k, 0.0]))\n", + " )\n", "\n", "object_poses = jnp.array(poses_list)\n", "\n", - "# testing a variety of models \n", - "idx_list = [i for i in range(10,19)]\n", - "m_paths = [] \n", + "# testing a variety of models\n", + "idx_list = [i for i in range(10, 19)]\n", + "m_paths = []\n", "for idx in idx_list:\n", " m_paths.append(model_paths[idx])" ] @@ -246,7 +265,9 @@ "metadata": {}, "outputs": [], "source": [ - "all_data = j.kubric_interface.render_multiobject_parallel(m_paths, object_poses[:,None,...], intrinsics, scaling_factor=1.0, lighting=3.0) # multi img singleobj\n", + "all_data = j.kubric_interface.render_multiobject_parallel(\n", + " m_paths, object_poses[:, None, ...], intrinsics, scaling_factor=1.0, lighting=3.0\n", + ") # multi img singleobj\n", "rgbd = all_data[0]\n", "j.get_rgb_image(rgbd.rgb)" ] @@ -279,7 +300,7 @@ "metadata": {}, "outputs": [], "source": [ - "low_res = make_low_resolution(rgbd.rgb, scale_factor=.5)\n", + "low_res = make_low_resolution(rgbd.rgb, scale_factor=0.5)\n", "j.get_rgb_image(low_res)" ] }, @@ -300,7 +321,7 @@ "metadata": {}, "outputs": [], "source": [ - "gauss_low_res = make_low_resolution(gauss, scale_factor=.25)\n", + "gauss_low_res = make_low_resolution(gauss, scale_factor=0.25)\n", "low_res_gauss = add_gaussian_noise(low_res, variance=300)\n", "j.get_rgb_image(gauss_low_res)" ] @@ -332,10 +353,16 @@ "metadata": {}, "outputs": [], "source": [ - "pred = j.cosypose_utils.cosypose_interface(low_res_gauss, j.K_from_intrinsics(rgbd.intrinsics))\n", - "pred_poses, pred_ids, pred_scores = pred['pred_poses'], pred['pred_ids'], pred['pred_scores']\n", + "pred = j.cosypose_utils.cosypose_interface(\n", + " low_res_gauss, 
j.K_from_intrinsics(rgbd.intrinsics)\n", + ")\n", + "pred_poses, pred_ids, pred_scores = (\n", + " pred[\"pred_poses\"],\n", + " pred[\"pred_ids\"],\n", + " pred[\"pred_scores\"],\n", + ")\n", "rendered = renderer.render_multiobject(jnp.array(pred_poses[0]), jnp.array(pred_ids[0]))\n", - "j.get_depth_image(rendered[:,:,2])" + "j.get_depth_image(rendered[:, :, 2])" ] }, { @@ -367,25 +394,28 @@ "model_dir = os.path.join(j.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "print(f\"{model_dir} exists: {os.path.exists(model_dir)}\")\n", "model_names = j.ycb_loader.MODEL_NAMES\n", - "model_paths = [os.path.join(model_dir,name,\"textured.obj\") for name in model_names]\n", + "model_paths = [os.path.join(model_dir, name, \"textured.obj\") for name in model_names]\n", "\n", "bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img(\"52\", \"1\", bop_ycb_dir)\n", "intrinsics = j.Intrinsics(\n", " height=rgbd.intrinsics.height,\n", " width=rgbd.intrinsics.width,\n", - " fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx,\n", - " cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0,\n", - " near=0.001, far=3.0\n", + " fx=rgbd.intrinsics.fx,\n", + " fy=rgbd.intrinsics.fx,\n", + " cx=rgbd.intrinsics.width / 2.0,\n", + " cy=rgbd.intrinsics.height / 2.0,\n", + " near=0.001,\n", + " far=3.0,\n", ")\n", "renderer = j.Renderer(rgbd.intrinsics, num_layers=25)\n", "model_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "model_names = [\"obj_\" + f\"{str(idx+1).rjust(6, '0')}.ply\" for idx in range(21)]\n", "mesh_paths = []\n", "for name in model_names:\n", - " mesh_path = os.path.join(model_dir,name)\n", + " mesh_path = os.path.join(model_dir, name)\n", " mesh_paths.append(mesh_path)\n", - " model_scaling_factor = 1.0/1000.0" + " model_scaling_factor = 1.0 / 1000.0" ] }, { @@ -395,7 +425,7 @@ "metadata": {}, "outputs": [], "source": [ - "#poses\n", + "# poses\n", "camera_pose = j.t3d.transform_from_pos_target_up(\n", " jnp.array([0.5, 0.5, 0.5]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", @@ -408,16 +438,18 @@ "poses_list = []\n", "for i in range(3):\n", " for k in range(3):\n", - " poses_list.append(object_pose @ j.t3d.transform_from_pos(jnp.array([0.1*i, 0.1*k, 0.0])))\n", + " poses_list.append(\n", + " object_pose @ j.t3d.transform_from_pos(jnp.array([0.1 * i, 0.1 * k, 0.0]))\n", + " )\n", "\n", "object_poses = jnp.array(poses_list)\n", "\n", - "# model paths \n", - "# a list from one to nine \n", - "idx_list = [i for i in range(10,19)]\n", + "# model paths\n", + "# a list from one to nine\n", + "idx_list = [i for i in range(10, 19)]\n", "\n", - "#add model paths to list based on idx_list\n", - "m_paths = [] \n", + "# add model paths to list based on idx_list\n", + "m_paths = []\n", "for idx in idx_list:\n", " m_paths.append(model_paths[idx])" ] @@ -429,7 +461,9 @@ "metadata": {}, "outputs": [], "source": [ - "all_data = j.kubric_interface.render_multiobject_parallel(m_paths, object_poses[:,None,...], intrinsics, scaling_factor=1.0, lighting=3.0) # multi img singleobj\n", + "all_data = j.kubric_interface.render_multiobject_parallel(\n", + " m_paths, object_poses[:, None, ...], intrinsics, scaling_factor=1.0, lighting=3.0\n", + ") # multi img singleobj\n", "rgbd = all_data[0]\n", "j.get_rgb_image(rgbd.rgb)" ] @@ -441,8 +475,14 @@ "metadata": {}, "outputs": [], "source": [ - "pred = 
j.cosypose_utils.cosypose_interface(np.array(rgbd.rgb), j.K_from_intrinsics(rgbd.intrinsics))\n", - "pred_poses, pred_ids, pred_scores = pred['pred_poses'], pred['pred_ids'], pred['pred_scores']" + "pred = j.cosypose_utils.cosypose_interface(\n", + " np.array(rgbd.rgb), j.K_from_intrinsics(rgbd.intrinsics)\n", + ")\n", + "pred_poses, pred_ids, pred_scores = (\n", + " pred[\"pred_poses\"],\n", + " pred[\"pred_ids\"],\n", + " pred[\"pred_scores\"],\n", + ")" ] }, { @@ -453,7 +493,7 @@ "outputs": [], "source": [ "rendered = renderer.render_multiobject(jnp.array(pred_poses[0]), jnp.array(pred_ids[0]))\n", - "j.get_depth_image(rendered[:,:,2])" + "j.get_depth_image(rendered[:, :, 2])" ] } ], diff --git a/scripts/experiments/deeplearning/kubric_dataset_gen/densefusion_test.ipynb b/scripts/experiments/deeplearning/kubric_dataset_gen/densefusion_test.ipynb index 95c68e2c..6213a240 100644 --- a/scripts/experiments/deeplearning/kubric_dataset_gen/densefusion_test.ipynb +++ b/scripts/experiments/deeplearning/kubric_dataset_gen/densefusion_test.ipynb @@ -26,13 +26,16 @@ "\n", "\n", "bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img(\"52\", \"1\", bop_ycb_dir)\n", "intrinsics = j.Intrinsics(\n", " height=rgbd.intrinsics.height,\n", " width=rgbd.intrinsics.width,\n", - " fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx,\n", - " cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0,\n", - " near=0.001, far=2.0\n", + " fx=rgbd.intrinsics.fx,\n", + " fy=rgbd.intrinsics.fx,\n", + " cx=rgbd.intrinsics.width / 2.0,\n", + " cy=rgbd.intrinsics.height / 2.0,\n", + " near=0.001,\n", + " far=2.0,\n", ")" ] }, @@ -44,6 +47,7 @@ "outputs": [], "source": [ "import bayes3d.posecnn_densefusion\n", + "\n", "densefusion = jax3dp3.posecnn_densefusion.DenseFusion()" ] }, @@ -54,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "mesh_path = os.path.join(model_dir,name,\"textured.obj\")\n", + "mesh_path = os.path.join(model_dir, name, \"textured.obj\")\n", "print(mesh_path)\n", "mesh = j.mesh.load_mesh(mesh_path)" ] @@ -69,11 +73,11 @@ "NUM_IMAGES_PER_ITER = 5\n", "FIXED_TRANSLATION = jnp.array([0.0, 0.08324493, 1.0084537])\n", "_seed = 1222\n", - "key = jax.random.PRNGKey(_seed) \n", + "key = jax.random.PRNGKey(_seed)\n", "object_poses = jax.vmap(lambda key: j.distributions.gaussian_vmf(key, 0.00001, 0.001))(\n", " jax.random.split(key, NUM_IMAGES_PER_ITER)\n", ")\n", - "object_poses = object_poses.at[:,:3,3].set(FIXED_TRANSLATION)" + "object_poses = object_poses.at[:, :3, 3].set(FIXED_TRANSLATION)" ] }, { @@ -91,20 +95,33 @@ "\n", "# generate and save the dataset\n", "if not load_from_existing:\n", - " rgbds = j.kubric_interface.render_multiobject_parallel([mesh_path], object_poses[None,...],\n", - " intrinsics, scaling_factor=1.0, lighting=1.0) # multi img singleobj\n", - " np.savez(DATASET_FILE, rgbds=rgbds, poses=object_poses, id=IDX, name=model_names[IDX], intrinsics=intrinsics, mesh_path=mesh_path)\n", + " rgbds = j.kubric_interface.render_multiobject_parallel(\n", + " [mesh_path],\n", + " object_poses[None, ...],\n", + " intrinsics,\n", + " scaling_factor=1.0,\n", + " lighting=1.0,\n", + " ) # multi img singleobj\n", + " np.savez(\n", + " DATASET_FILE,\n", + " rgbds=rgbds,\n", + " poses=object_poses,\n", + " id=IDX,\n", + " name=model_names[IDX],\n", + " intrinsics=intrinsics,\n", + " mesh_path=mesh_path,\n", + " )\n", "\n", "# or load 
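+    "# Caching pattern: the first run renders with Kubric and persists everything via\n",
+    "# np.savez(DATASET_FILE, ...); with load_from_existing=True the branch below\n",
+    "# reloads it through np.load(DATASET_FILE, allow_pickle=True) instead.\n",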
preexisting dataset\n", "else:\n", - " data = np.load(DATASET_FILE,allow_pickle=True)\n", + " data = np.load(DATASET_FILE, allow_pickle=True)\n", " rgbds = data[\"rgbds\"]\n", " object_poses = data[\"poses\"]\n", " id = data[\"id\"].item()\n", "\n", "rgb_images = j.hvstack_images([j.get_rgb_image(r.rgb) for r in rgbds], 1, 5)\n", "rgb_images.show()\n", - "rgb_images.save(f\"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.png\")\n" + "rgb_images.save(f\"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.png\")" ] }, { @@ -117,7 +134,9 @@ "# run densefusion on dataset\n", "all_results = []\n", "for scene_idx, rgbd in enumerate(rgbds):\n", - " results = densefusion.get_densefusion_results(rgbd.rgb, rgbd.depth, rgbd.intrinsics, scene_name=str(scene_idx))\n", + " results = densefusion.get_densefusion_results(\n", + " rgbd.rgb, rgbd.depth, rgbd.intrinsics, scene_name=str(scene_idx)\n", + " )\n", " all_results.extend(results)" ] }, @@ -130,13 +149,14 @@ "source": [ "# process densefusion results\n", "import pickle\n", - "with open(f\"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.pkl\", 'wb') as f:\n", + "\n", + "with open(f\"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.pkl\", \"wb\") as f:\n", " pickle.dump(all_results, f)\n", "\n", - "translation_err = jnp.zeros((1,3))\n", + "translation_err = jnp.zeros((1, 3))\n", "for results in all_results:\n", - " pred_rot = results[name]['rot_q']\n", - " pred_transl = results[name]['tr']\n", + " pred_rot = results[name][\"rot_q\"]\n", + " pred_transl = results[name][\"tr\"]\n", " translation_err += pred_transl\n", "\n", "avg_translation_err = translation_err / len(all_results)\n", diff --git a/scripts/experiments/deeplearning/kubric_dataset_gen/get dense fusion to work.ipynb b/scripts/experiments/deeplearning/kubric_dataset_gen/get dense fusion to work.ipynb index da3b0174..8beb1cf9 100644 --- a/scripts/experiments/deeplearning/kubric_dataset_gen/get dense fusion to work.ipynb +++ b/scripts/experiments/deeplearning/kubric_dataset_gen/get dense fusion to work.ipynb @@ -8,7 +8,10 @@ "outputs": [], "source": [ "import sys\n", - "sys.path.insert(0,'/home/nishadgothoskar/jax3dp3/jax3dp3/posecnn-pytorch/PoseCNN-PyTorch/')\n", + "\n", + "sys.path.insert(\n", + " 0, \"/home/nishadgothoskar/jax3dp3/jax3dp3/posecnn-pytorch/PoseCNN-PyTorch/\"\n", + ")\n", "import jax.numpy as jnp\n", "import bayes3d as b\n", "import trimesh\n", @@ -28,13 +31,16 @@ "\n", "\n", "bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img(\"52\", \"1\", bop_ycb_dir)\n", "intrinsics = j.Intrinsics(\n", " height=rgbd.intrinsics.height,\n", " width=rgbd.intrinsics.width,\n", - " fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx,\n", - " cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0,\n", - " near=0.001, far=1.0\n", + " fx=rgbd.intrinsics.fx,\n", + " fy=rgbd.intrinsics.fx,\n", + " cx=rgbd.intrinsics.width / 2.0,\n", + " cy=rgbd.intrinsics.height / 2.0,\n", + " near=0.001,\n", + " far=1.0,\n", ")" ] }, @@ -46,7 +52,8 @@ "outputs": [], "source": [ "import bayes3d.posecnn_densefusion\n", - "densefusion = jax3dp3.posecnn_densefusion.DenseFusion()\n" + "\n", + "densefusion = jax3dp3.posecnn_densefusion.DenseFusion()" ] }, { @@ -56,8 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "mesh_path = os.path.join(model_dir,name,\"textured.obj\")\n", + "mesh_path = os.path.join(model_dir, name, \"textured.obj\")\n", "print(mesh_path)\n", "mesh 
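+    "# NB: model_dir and name are assumed to be defined in an earlier cell. Also,\n",
+    "# the cell above imports bayes3d.posecnn_densefusion but instantiates\n",
+    "# jax3dp3.posecnn_densefusion.DenseFusion(); this looks like a leftover of the\n",
+    "# jax3dp3 -> bayes3d rename and presumably should use the bayes3d module.\n",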
= j.mesh.load_mesh(mesh_path)" ] @@ -76,7 +82,7 @@ ")\n", "# object_poses = jnp.array([jnp.eye(4)])\n", "# object_poses = object_poses.at[:,:3,3].set(jnp.array([0.0, 0.0, 1.0]))\n", - "object_poses = object_poses.at[:,:3,3].set(jnp.array([0.083, 0.08324493, 1.0084537 ]))" + "object_poses = object_poses.at[:, :3, 3].set(jnp.array([0.083, 0.08324493, 1.0084537]))" ] }, { @@ -86,8 +92,9 @@ "metadata": {}, "outputs": [], "source": [ - "all_data = j.kubric_interface.render_multiobject_parallel([mesh_path], object_poses[None,...],\n", - " intrinsics, scaling_factor=1.0, lighting=2.0) # multi img singleobj\n", + "all_data = j.kubric_interface.render_multiobject_parallel(\n", + " [mesh_path], object_poses[None, ...], intrinsics, scaling_factor=1.0, lighting=2.0\n", + ") # multi img singleobj\n", "\n", "rgbd = all_data[0]" ] @@ -111,7 +118,9 @@ "metadata": {}, "outputs": [], "source": [ - "results = densefusion.get_densefusion_results(rgb, rgbd.depth, rgbd.intrinsics, scene_name=\"1\")" + "results = densefusion.get_densefusion_results(\n", + " rgb, rgbd.depth, rgbd.intrinsics, scene_name=\"1\"\n", + ")" ] }, { diff --git a/scripts/experiments/deeplearning/sam/fastsam.ipynb b/scripts/experiments/deeplearning/sam/fastsam.ipynb index e221f236..4f529fbd 100644 --- a/scripts/experiments/deeplearning/sam/fastsam.ipynb +++ b/scripts/experiments/deeplearning/sam/fastsam.ipynb @@ -11,7 +11,8 @@ "import os\n", "import numpy as np\n", "import sys\n", - "sys.path.append('/home/nishadgothoskar/FastSAM/')" + "\n", + "sys.path.append(\"/home/nishadgothoskar/FastSAM/\")" ] }, { @@ -21,7 +22,7 @@ "outputs": [], "source": [ "import argparse\n", - "from fastsam import FastSAM, FastSAMPrompt \n", + "from fastsam import FastSAM, FastSAMPrompt\n", "import ast\n", "import torch\n", "from PIL import Image\n", @@ -59,7 +60,7 @@ "outputs": [], "source": [ "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('53', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"53\", \"1\", bop_ycb_dir)\n", "input = b.get_rgb_image(rgbd.rgb).convert(\"RGB\")\n", "input" ] @@ -70,18 +71,15 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "everything_results = model(\n", - "input,\n", - "device=device,\n", - "retina_masks=True,\n", - "imgsz=1024,\n", - "conf=0.4,\n", - "iou=0.95 \n", + " input, device=device, retina_masks=True, imgsz=1024, conf=0.4, iou=0.95\n", ")\n", "prompt_process = FastSAMPrompt(input, everything_results, device=device)\n", "ann = prompt_process.everything_prompt()\n", - "prompt_process.plot(annotations=ann,output_path='output.jpg',)\n" + "prompt_process.plot(\n", + " annotations=ann,\n", + " output_path=\"output.jpg\",\n", + ")" ] }, { diff --git a/scripts/experiments/deeplearning/sam/sam.ipynb b/scripts/experiments/deeplearning/sam/sam.ipynb index a7ce4bcc..5c633c4f 100644 --- a/scripts/experiments/deeplearning/sam/sam.ipynb +++ b/scripts/experiments/deeplearning/sam/sam.ipynb @@ -16,23 +16,28 @@ "import warnings\n", "from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, build_sam\n", "import sys\n", + "\n", "sys.path.extend([\"/home/nishadgothoskar/ptamp/pybullet_planning\"])\n", "sys.path.extend([\"/home/nishadgothoskar/ptamp\"])\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "\n", "bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + "rgbd, gt_ids, 
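+    "# Several candidate sample scenes are assigned to test_pkl_file below; only the\n",
+    "# last assignment (demo2_nolight.pkl) takes effect, the rest are kept for quick\n",
+    "# switching between scenes.\n",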
gt_poses, masks = j.ycb_loader.get_test_img(\"52\", \"1\", bop_ycb_dir)\n", "\n", - "test_pkl_file = os.path.join(j.utils.get_assets_dir(),\"sample_imgs/strawberry_error.pkl\")\n", - "test_pkl_file = os.path.join(j.utils.get_assets_dir(),\"sample_imgs/knife_spoon_box_real.pkl\")\n", - "test_pkl_file = os.path.join(j.utils.get_assets_dir(),\"sample_imgs/red_lego_multi.pkl\")\n", - "test_pkl_file = os.path.join(j.utils.get_assets_dir(),\"sample_imgs/demo2_nolight.pkl\")\n", + "test_pkl_file = os.path.join(\n", + " j.utils.get_assets_dir(), \"sample_imgs/strawberry_error.pkl\"\n", + ")\n", + "test_pkl_file = os.path.join(\n", + " j.utils.get_assets_dir(), \"sample_imgs/knife_spoon_box_real.pkl\"\n", + ")\n", + "test_pkl_file = os.path.join(j.utils.get_assets_dir(), \"sample_imgs/red_lego_multi.pkl\")\n", + "test_pkl_file = os.path.join(j.utils.get_assets_dir(), \"sample_imgs/demo2_nolight.pkl\")\n", "\n", - "file = open(test_pkl_file,'rb')\n", + "file = open(test_pkl_file, \"rb\")\n", "camera_images = pickle.load(file)[\"camera_images\"]\n", "images = [j.RGBD.construct_from_camera_image(c) for c in camera_images]\n", - "rgbd = images[0]\n" + "rgbd = images[0]" ] }, { @@ -42,7 +47,7 @@ "metadata": {}, "outputs": [], "source": [ - "j.get_rgb_image(rgbd.rgb).save(\"rgb.png\")\n" + "j.get_rgb_image(rgbd.rgb).save(\"rgb.png\")" ] }, { @@ -52,9 +57,10 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "sam = build_sam(checkpoint=\"/home/nishadgothoskar/jax3dp3/assets/sam/sam_vit_h_4b8939.pth\")\n", - "sam.to(device=\"cuda\")\n" + "sam = build_sam(\n", + " checkpoint=\"/home/nishadgothoskar/jax3dp3/assets/sam/sam_vit_h_4b8939.pth\"\n", + ")\n", + "sam.to(device=\"cuda\")" ] }, { @@ -66,12 +72,12 @@ "source": [ "mask_generator = SamAutomaticMaskGenerator(\n", " model=sam,\n", - "# points_per_side=32,\n", - "# pred_iou_thresh=0.90,\n", - "# stability_score_thresh=0.95,\n", - "# crop_n_layers=0,\n", - "# crop_n_points_downscale_factor=1,\n", - "# min_mask_region_area=200, # Requires open-cv to run post-processing\n", + " # points_per_side=32,\n", + " # pred_iou_thresh=0.90,\n", + " # stability_score_thresh=0.95,\n", + " # crop_n_layers=0,\n", + " # crop_n_points_downscale_factor=1,\n", + " # min_mask_region_area=200, # Requires open-cv to run post-processing\n", ")" ] }, @@ -83,7 +89,7 @@ "outputs": [], "source": [ "image = np.array(rgbd.rgb)\n", - "masks = mask_generator.generate(image)\n" + "masks = mask_generator.generate(image)" ] }, { @@ -96,24 +102,27 @@ "def show_anns(anns):\n", " if len(anns) == 0:\n", " return\n", - " sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n", + " sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n", " ax = plt.gca()\n", " ax.set_autoscale_on(False)\n", " polygons = []\n", " color = []\n", " for ann in sorted_anns:\n", - " m = ann['segmentation']\n", + " m = ann[\"segmentation\"]\n", " img = np.ones((m.shape[0], m.shape[1], 3))\n", " color_mask = np.random.random((1, 3)).tolist()[0]\n", " for i in range(3):\n", - " img[:,:,i] = color_mask[i]\n", - " ax.imshow(np.dstack((img, m*0.35)))\n", + " img[:, :, i] = color_mask[i]\n", + " ax.imshow(np.dstack((img, m * 0.35)))\n", + "\n", + "\n", "import matplotlib.pyplot as plt\n", - "plt.figure(figsize=(20,20))\n", + "\n", + "plt.figure(figsize=(20, 20))\n", "plt.imshow(image)\n", "show_anns(masks)\n", - "plt.axis('off')\n", - "plt.show() " + "plt.axis(\"off\")\n", + "plt.show()" ] }, { @@ -131,17 +140,19 @@ "\n", " matched = False\n", " for jj in range(num_objects_so_far):\n", - " 
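+    "# Merge heuristic for the loop below: a new SAM mask is absorbed into an existing\n",
+    "# object id when over 90% of its pixels overlap that object's mask\n",
+    "# (intersection[seg_mask].mean() > 0.9); otherwise it is given a fresh id.\n",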
seg_mask_existing_object = (full_segmentation == jj)\n", - " \n", + " seg_mask_existing_object = full_segmentation == jj\n", + "\n", " intersection = seg_mask * seg_mask_existing_object\n", " if intersection[seg_mask].mean() > 0.9:\n", " matched = True\n", - " \n", + "\n", " if not matched:\n", " full_segmentation = full_segmentation.at[seg_mask].set(num_objects_so_far)\n", " num_objects_so_far += 1\n", "\n", - " segmentation_image = j.get_depth_image(full_segmentation + 1,max=full_segmentation.max() + 2)\n", + " segmentation_image = j.get_depth_image(\n", + " full_segmentation + 1, max=full_segmentation.max() + 2\n", + " )\n", " seg_viz = j.get_depth_image(seg_mask)\n", " viz_images.append(j.hstack_images([segmentation_image, seg_viz]))" ] diff --git a/scripts/experiments/gaussian_splatting/3dgs_from_rgb_video.ipynb b/scripts/experiments/gaussian_splatting/3dgs_from_rgb_video.ipynb index 219a358b..f828487f 100644 --- a/scripts/experiments/gaussian_splatting/3dgs_from_rgb_video.ipynb +++ b/scripts/experiments/gaussian_splatting/3dgs_from_rgb_video.ipynb @@ -51,11 +51,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.001, far=6.0\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.001, far=6.0\n", ")" ] }, @@ -76,7 +72,6 @@ } ], "source": [ - "\n", "visualizer = Open3DVisualizer(intrinsics)" ] }, @@ -87,10 +82,11 @@ "outputs": [], "source": [ "import os\n", + "\n", "model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "i = 4\n", "mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i], \"textured.obj\")\n", - "mesh = o3d.io.read_triangle_model(mesh_path)" + "mesh = o3d.io.read_triangle_model(mesh_path)" ] }, { @@ -121,7 +117,12 @@ "source": [ "%%time\n", "visualizer.render.scene.clear_geometry()\n", - "pose = b.distributions.gaussian_vmf(jax.random.PRNGKey(1000),b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])), 0.01, 1.0)\n", + "pose = b.distributions.gaussian_vmf(\n", + " jax.random.PRNGKey(1000),\n", + " b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])),\n", + " 0.01,\n", + " 1.0,\n", + ")\n", "mesh.meshes[0].mesh.transform(pose)\n", "visualizer.render.scene.add_model(f\"m\", mesh)\n", "mesh.meshes[0].mesh.transform(jnp.linalg.inv(pose))\n", @@ -135,8 +136,8 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(5).rjust(6, '0') + \".ply\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(5).rjust(6, \"0\") + \".ply\")\n", "mesh = b.utils.load_mesh(mesh_path)\n", "vertices = jnp.array(mesh.vertices) / 1000.0" ] @@ -148,15 +149,22 @@ "outputs": [], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import math\n", "import torch\n", "from diff_gaussian_rasterization import _C\n", "import numpy as np\n", "import functools\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", - "def gaussian_raster_fwd(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics):\n", + "\n", + "def 
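+    "# The forward pass bridges JAX and torch: inputs are converted with\n",
+    "# b.utils.jax_to_torch, rasterized by diff_gaussian_rasterization's _C kernel,\n",
+    "# and the rendered color is converted back with b.utils.torch_to_jax so the\n",
+    "# whole call can sit under jax.custom_vjp.\n",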
gaussian_raster_fwd(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + "):\n", " fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2\n", " fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2\n", " tan_fovx = math.tan(fovX)\n", @@ -164,14 +172,16 @@ "\n", " means3D = b.utils.jax_to_torch(means3D)\n", " N = means3D.shape[0]\n", - " means2D = torch.ones((N, 3),requires_grad=True, device=device)\n", - " \n", + " means2D = torch.ones((N, 3), requires_grad=True, device=device)\n", + "\n", " opacities = b.utils.jax_to_torch(opacities)\n", " scales = b.utils.jax_to_torch(scales)\n", " rotations = b.utils.jax_to_torch(rotations)\n", " colors_precomp = b.utils.jax_to_torch(colors_precomp)\n", "\n", - " view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose))),0,1).cuda()\n", + " view_matrix = torch.transpose(\n", + " torch.tensor(np.array(b.inverse_pose(camera_pose))), 0, 1\n", + " ).cuda()\n", "\n", " def getProjectionMatrix(intrinsics):\n", " top = intrinsics.near / intrinsics.fy * intrinsics.height / 2.0\n", @@ -187,8 +197,16 @@ " P[1, 1] = 2.0 * intrinsics.near / (top - bottom)\n", " P[0, 2] = (right + left) / (right - left)\n", " P[1, 2] = (top + bottom) / (top - bottom)\n", - " P[2, 2] = z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", - " P[2, 3] = -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " P[2, 2] = (\n", + " z_sign\n", + " * (intrinsics.far + intrinsics.near)\n", + " / (intrinsics.far - intrinsics.near)\n", + " )\n", + " P[2, 3] = (\n", + " -2.0\n", + " * (intrinsics.far * intrinsics.near)\n", + " / (intrinsics.far - intrinsics.near)\n", + " )\n", " P[3, 2] = z_sign\n", " return torch.transpose(P, 0, 1)\n", "\n", @@ -206,13 +224,13 @@ " sh_degree=1,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", " )\n", " cov3Ds_precomp = torch.Tensor([])\n", " sh = torch.Tensor([])\n", " # (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: torch.Tensor, arg6: float, arg7: torch.Tensor, arg8: torch.Tensor, arg9: torch.Tensor, arg10: float, arg11: float, arg12: int, arg13: int, arg14: torch.Tensor, arg15: int, arg16: torch.Tensor, arg17: bool, arg18: bool)\n", " args = (\n", - " raster_settings.bg, \n", + " raster_settings.bg,\n", " means3D,\n", " colors_precomp,\n", " opacities,\n", @@ -230,44 +248,64 @@ " raster_settings.sh_degree,\n", " raster_settings.campos,\n", " raster_settings.prefiltered,\n", - " raster_settings.debug\n", + " raster_settings.debug,\n", " )\n", "\n", - " num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)\n", + " num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = (\n", + " _C.rasterize_gaussians(*args)\n", + " )\n", " return b.utils.torch_to_jax(color), (num_rendered,)\n", + "\n", + "\n", "# num_rendered, colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, raster_settings)\n", "\n", + "\n", "def gaussian_raster_bwd(saved_tensors, grad_out_color):\n", " # (num_rendered, colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, raster_settings) = saved_tensors\n", " return None, None, None\n", " # Restructure args as C++ method expects them\n", - " args = (raster_settings.bg,\n", - " means3D, \n", - " radii, \n", - " colors_precomp, \n", - " 
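+    "# Note: the backward pass above is effectively stubbed out -- it returns\n",
+    "# (None, None, None) before reaching _C.rasterize_gaussians_backward -- and the\n",
+    "# argument plumbing below is unreachable, kept for reference.\n",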
scales, \n", - " rotations, \n", - " raster_settings.scale_modifier, \n", - " cov3Ds_precomp, \n", - " raster_settings.viewmatrix, \n", - " raster_settings.projmatrix, \n", - " raster_settings.tanfovx, \n", - " raster_settings.tanfovy, \n", - " b.utils.jax_to_torch(grad_out_color), \n", - " sh, \n", - " raster_settings.sh_degree, \n", - " raster_settings.campos,\n", - " geomBuffer,\n", - " num_rendered,\n", - " binningBuffer,\n", - " imgBuffer,\n", - " raster_settings.debug)\n", - " grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)\n", - " return (b.utils.torch_to_jax(grad_means3D),\n", + " args = (\n", + " raster_settings.bg,\n", + " means3D,\n", + " radii,\n", + " colors_precomp,\n", + " scales,\n", + " rotations,\n", + " raster_settings.scale_modifier,\n", + " cov3Ds_precomp,\n", + " raster_settings.viewmatrix,\n", + " raster_settings.projmatrix,\n", + " raster_settings.tanfovx,\n", + " raster_settings.tanfovy,\n", + " b.utils.jax_to_torch(grad_out_color),\n", + " sh,\n", + " raster_settings.sh_degree,\n", + " raster_settings.campos,\n", + " geomBuffer,\n", + " num_rendered,\n", + " binningBuffer,\n", + " imgBuffer,\n", + " raster_settings.debug,\n", + " )\n", + " (\n", + " grad_means2D,\n", + " grad_colors_precomp,\n", + " grad_opacities,\n", + " grad_means3D,\n", + " grad_cov3Ds_precomp,\n", + " grad_sh,\n", + " grad_scales,\n", + " grad_rotations,\n", + " ) = _C.rasterize_gaussians_backward(*args)\n", + " return (\n", + " b.utils.torch_to_jax(grad_means3D),\n", " b.utils.torch_to_jax(grad_opacities),\n", " b.utils.torch_to_jax(grad_scales),\n", " b.utils.torch_to_jax(grad_rotations),\n", - " b.utils.torch_to_jax(grad_colors_precomp),None, None)\n", + " b.utils.torch_to_jax(grad_colors_precomp),\n", + " None,\n", + " None,\n", + " )\n", "\n", " # grads = (\n", " # grad_means3D,\n", @@ -280,12 +318,24 @@ " # grad_cov3Ds_precomp,\n", " # None,\n", " # )\n", - "@functools.partial(jax.custom_vjp, nondiff_argnums=(5,6,))\n", - "def gaussian_raster(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics):\n", - " return gaussian_raster_fwd(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics)[0]\n", "\n", - "gaussian_raster.defvjp(gaussian_raster_fwd , gaussian_raster_bwd)\n", - "\n" + "\n", + "@functools.partial(\n", + " jax.custom_vjp,\n", + " nondiff_argnums=(\n", + " 5,\n", + " 6,\n", + " ),\n", + ")\n", + "def gaussian_raster(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + "):\n", + " return gaussian_raster_fwd(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + " )[0]\n", + "\n", + "\n", + "gaussian_raster.defvjp(gaussian_raster_fwd, gaussian_raster_bwd)" ] }, { @@ -301,7 +351,7 @@ "metadata": {}, "outputs": [], "source": [ - "gt_color = jnp.transpose(rgbd.rgb, (2,0,1))[:3,...]" + "gt_color = jnp.transpose(rgbd.rgb, (2, 0, 1))[:3, ...]" ] }, { @@ -333,21 +383,49 @@ "# quat = jax.random.uniform(jax.random.PRNGKey(31), (4,))\n", "# transform = b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), jnp.array([0.0, 0.0, 0.5]))\n", "means3D = b.apply_transform_jit(vertices, pose)\n", - "opacities = jnp.ones((means3D.shape[0],1), dtype=jnp.float32) * 10.0\n", - "scales = jnp.ones((means3D.shape[0],3), dtype=jnp.float32) * -10.0\n", - "rotations = jnp.ones((means3D.shape[0],4), dtype=jnp.float32)\n", - "colors_precomp = 
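+    "# Parameterization note: opacities and scales are stored unconstrained here\n",
+    "# (10.0 and -10.0) and mapped through jax.nn.sigmoid / jnp.exp inside render()\n",
+    "# below, keeping them in (0, 1) and (0, inf) during gradient descent.\n",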
jnp.ones((means3D.shape[0],3), dtype=jnp.float32) * 255.0\n", + "opacities = jnp.ones((means3D.shape[0], 1), dtype=jnp.float32) * 10.0\n", + "scales = jnp.ones((means3D.shape[0], 3), dtype=jnp.float32) * -10.0\n", + "rotations = jnp.ones((means3D.shape[0], 4), dtype=jnp.float32)\n", + "colors_precomp = jnp.ones((means3D.shape[0], 3), dtype=jnp.float32) * 255.0\n", "camera_pose = jnp.eye(4)\n", "\n", - "def render(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics):\n", - " color = gaussian_raster(means3D, jax.nn.sigmoid(opacities), jnp.exp(scales), rotations, colors_precomp, camera_pose, intrinsics)\n", - " return color \n", - "def loss(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics):\n", - " color = gaussian_raster(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics)\n", + "\n", + "def render(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + "):\n", + " color = gaussian_raster(\n", + " means3D,\n", + " jax.nn.sigmoid(opacities),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " colors_precomp,\n", + " camera_pose,\n", + " intrinsics,\n", + " )\n", + " return color\n", + "\n", + "\n", + "def loss(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + "):\n", + " color = gaussian_raster(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + " )\n", " return jnp.mean(jnp.abs(color - gt_color))\n", - "grad = jax.grad(loss, argnums=(0,1,2,))\n", - "color = render(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics)\n", - "b.get_rgb_image(jnp.transpose(color, (1,2,0)))" + "\n", + "\n", + "grad = jax.grad(\n", + " loss,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " 2,\n", + " ),\n", + ")\n", + "color = render(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + ")\n", + "b.get_rgb_image(jnp.transpose(color, (1, 2, 0)))" ] }, { @@ -499,7 +577,9 @@ ], "source": [ "%%time\n", - "color = gaussian_raster_fwd_jax(means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics)" + "color = gaussian_raster_fwd_jax(\n", + " means3D, opacities, scales, rotations, colors_precomp, camera_pose, intrinsics\n", + ")" ] }, { @@ -527,7 +607,7 @@ } ], "source": [ - "b.get_depth_image(color[0,...])" + "b.get_depth_image(color[0, ...])" ] }, { @@ -557,6 +637,7 @@ ], "source": [ "import diff_gaussian_rasterization\n", + "\n", "diff_gaussian_rasterization" ] }, @@ -594,7 +675,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"1\",jnp.array(mesh.meshes[0].mesh.vertices))" + "b.show_cloud(\"1\", jnp.array(mesh.meshes[0].mesh.vertices))" ] }, { diff --git a/scripts/experiments/gaussian_splatting/3dgs_jax.ipynb b/scripts/experiments/gaussian_splatting/3dgs_jax.ipynb index 2467c70a..8c99f44f 100644 --- a/scripts/experiments/gaussian_splatting/3dgs_jax.ipynb +++ b/scripts/experiments/gaussian_splatting/3dgs_jax.ipynb @@ -7,7 +7,10 @@ "outputs": [], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -20,10 +23,12 @@ "import jax\n", "from jax.scipy.spatial.transform import Rotation as R\n", "import bayes3d.utils.gaussian_splatting as 
bgs\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device\n", "import bayes3d.utils.gaussian_splatting as bgs\n", - "key = jax.random.PRNGKey(0)\n" + "\n", + "key = jax.random.PRNGKey(0)" ] }, { @@ -65,41 +70,54 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "key = jax.random.split(key, 1)[0]\n", - "means3D = jax.random.uniform(key, (10,3), minval=-0.5, maxval=0.5) + jnp.array([0.0, 0.0, 1.0])\n", + "means3D = jax.random.uniform(key, (10, 3), minval=-0.5, maxval=0.5) + jnp.array(\n", + " [0.0, 0.0, 1.0]\n", + ")\n", "N = len(means3D)\n", - "opacity = jnp.ones((N,1))\n", + "opacity = jnp.ones((N, 1))\n", "key = jax.random.split(key, 1)[0]\n", - "colors_precomp = jax.random.uniform(key, (N,3))\n", - "scales = -8.0 * jnp.ones((N,3))\n", - "rotations = -10.0 * jnp.ones((N,4))\n", + "colors_precomp = jax.random.uniform(key, (N, 3))\n", + "scales = -8.0 * jnp.ones((N, 3))\n", + "rotations = -10.0 * jnp.ones((N, 4))\n", "camera_pose = jnp.eye(4)\n", - "color_gt = bgs.gaussian_raster(means3D,\n", - " colors_precomp, jax.nn.sigmoid(opacity), jnp.exp(scales), rotations, camera_pose, intrinsics)\n", + "color_gt = bgs.gaussian_raster(\n", + " means3D,\n", + " colors_precomp,\n", + " jax.nn.sigmoid(opacity),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + ")\n", "# gradients = bgs.gaussian_raster_bwd(saved_tensors, color_gt)\n", - "gt_viz = b.get_rgb_image(jnp.transpose(color_gt, (1,2,0))[...,:3] * 255.0)\n", + "gt_viz = b.get_rgb_image(jnp.transpose(color_gt, (1, 2, 0))[..., :3] * 255.0)\n", "\n", "key = jax.random.split(key, 1)[0]\n", - "means3D = jax.random.uniform(key, (20,3), minval=-0.5, maxval=0.5) + jnp.array([0.0, 0.0, 1.0])\n", + "means3D = jax.random.uniform(key, (20, 3), minval=-0.5, maxval=0.5) + jnp.array(\n", + " [0.0, 0.0, 1.0]\n", + ")\n", "N = len(means3D)\n", - "opacity = jnp.ones((N,1))\n", + "opacity = jnp.ones((N, 1))\n", "key = jax.random.split(key, 1)[0]\n", - "colors_precomp = jax.random.uniform(key, (N,3))\n", - "scales = -8.0 * jnp.ones((N,3))\n", - "rotations = -10.0 * jnp.ones((N,4))\n", + "colors_precomp = jax.random.uniform(key, (N, 3))\n", + "scales = -8.0 * jnp.ones((N, 3))\n", + "rotations = -10.0 * jnp.ones((N, 4))\n", "camera_pose = jnp.eye(4)\n", - "color = bgs.gaussian_raster(means3D,\n", - " colors_precomp, jax.nn.sigmoid(opacity), jnp.exp(scales), rotations, camera_pose, intrinsics)\n", - "b.hstack_images([\n", - " b.get_rgb_image(jnp.transpose(color, (1,2,0))[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "color = bgs.gaussian_raster(\n", + " means3D,\n", + " colors_precomp,\n", + " jax.nn.sigmoid(opacity),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + ")\n", + "b.hstack_images(\n", + " [b.get_rgb_image(jnp.transpose(color, (1, 2, 0))[..., :3] * 255.0), gt_viz]\n", + ")" ] }, { @@ -135,20 +153,55 @@ ], "source": [ "import jax.example_libraries.optimizers as optimizers\n", + "\n", + "\n", "# optimizer = torch.optim.Adam([\n", "# {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", "# {'params': [means3D], 'lr': 0.001, \"name\": \"pos\"},\n", "# ], lr=0.0, eps=1e-15)\n", - "def loss(means3D, colors_precomp, opacity, scales, rotations, camera_pose, intrinsics, color_gt):\n", - " color = 
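+    "# Two separate Adam states are used below so that colors (lr=0.01) and means\n",
+    "# (lr=0.001) can step at different rates, mirroring the commented-out torch\n",
+    "# param-group setup above.\n",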
bgs.gaussian_raster(means3D,\n", - " colors_precomp, jax.nn.sigmoid(opacity), jnp.exp(scales), rotations, camera_pose, intrinsics)\n", + "def loss(\n", + " means3D,\n", + " colors_precomp,\n", + " opacity,\n", + " scales,\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + " color_gt,\n", + "):\n", + " color = bgs.gaussian_raster(\n", + " means3D,\n", + " colors_precomp,\n", + " jax.nn.sigmoid(opacity),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + " )\n", " return jnp.mean(jnp.abs(color - color_gt))\n", - "grad_func = jax.value_and_grad(loss, argnums=(0,1,))\n", - "loss, (means_grad, colors_precomp_grad) = grad_func(means3D, colors_precomp, opacity, scales, rotations, camera_pose, intrinsics, color_gt)\n", + "\n", + "\n", + "grad_func = jax.value_and_grad(\n", + " loss,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " ),\n", + ")\n", + "loss, (means_grad, colors_precomp_grad) = grad_func(\n", + " means3D,\n", + " colors_precomp,\n", + " opacity,\n", + " scales,\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + " color_gt,\n", + ")\n", "\n", "\n", "opt_init, opt_update, get_params = optimizers.adam(0.01, eps=1e-15)\n", - "opt_init2, opt_update2, get_params2 = optimizers.adam(0.001,eps=1e-15)\n", + "opt_init2, opt_update2, get_params2 = optimizers.adam(0.001, eps=1e-15)\n", "opt_state = opt_init(colors_precomp)\n", "opt_state2 = opt_init2(means3D)\n", "\n", @@ -156,17 +209,32 @@ "for i in pbar:\n", " colors_precomp = get_params(opt_state)\n", " means3D = get_params2(opt_state2)\n", - " loss, (means_grad, colors_precomp_grad) = grad_func(means3D, colors_precomp, opacity, scales, rotations, camera_pose, intrinsics, color_gt)\n", + " loss, (means_grad, colors_precomp_grad) = grad_func(\n", + " means3D,\n", + " colors_precomp,\n", + " opacity,\n", + " scales,\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + " color_gt,\n", + " )\n", " opt_state = opt_update(i, colors_precomp_grad, opt_state)\n", " opt_state2 = opt_update2(i, means_grad, opt_state2)\n", " pbar.set_description(f\"loss: {loss}\")\n", "\n", - "color = bgs.gaussian_raster(means3D,\n", - " colors_precomp, jax.nn.sigmoid(opacity), jnp.exp(scales), rotations, camera_pose, intrinsics)\n", - "b.hstack_images([\n", - " b.get_rgb_image(jnp.transpose(color, (1,2,0))[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "color = bgs.gaussian_raster(\n", + " means3D,\n", + " colors_precomp,\n", + " jax.nn.sigmoid(opacity),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", + ")\n", + "b.hstack_images(\n", + " [b.get_rgb_image(jnp.transpose(color, (1, 2, 0))[..., :3] * 255.0), gt_viz]\n", + ")" ] }, { @@ -176,7 +244,7 @@ "outputs": [], "source": [ "r3d_path = \"/home/nishadgothoskar/bayes3d/scripts/2023-09-05--16-08-21.r3d.zip\"\n", - "colors, depths, poses, intrinsics, intrinsics_depth = b.utils.load_r3d(r3d_path);" + "colors, depths, poses, intrinsics, intrinsics_depth = b.utils.load_r3d(r3d_path)" ] }, { @@ -201,26 +269,38 @@ ], "source": [ "T = 0\n", - "color_gt = jnp.transpose(b.utils.resize(colors[T],intrinsics.height, intrinsics.width)[...,:3],(2,0,1)) / 255.0\n", - "gt_viz = b.get_rgb_image(jnp.transpose(color_gt, (1,2,0))[...,:3] * 255.0)\n", - "cloud = b.apply_transform(b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)).reshape(-1,3)\n", + "color_gt = (\n", + " jnp.transpose(\n", + " b.utils.resize(colors[T], intrinsics.height, intrinsics.width)[..., :3],\n", + " (2, 0, 1),\n", + " )\n", + " / 
255.0\n", + ")\n", + "gt_viz = b.get_rgb_image(jnp.transpose(color_gt, (1, 2, 0))[..., :3] * 255.0)\n", + "cloud = b.apply_transform(\n", + " b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)\n", + ").reshape(-1, 3)\n", "means3D = cloud\n", "N = cloud.shape[0]\n", - "opacity = jnp.ones((N,1))\n", - "colors_precomp = jax.random.uniform(key, (N,3))\n", - "scales = -8.0 * jnp.ones((N,3))\n", - "rotations = -10.0 * jnp.ones((N,4))\n", + "opacity = jnp.ones((N, 1))\n", + "colors_precomp = jax.random.uniform(key, (N, 3))\n", + "scales = -8.0 * jnp.ones((N, 3))\n", + "rotations = -10.0 * jnp.ones((N, 4))\n", "camera_pose = jnp.eye(4)\n", "N = means3D.shape[0]\n", "\n", "color = bgs.gaussian_raster(\n", " means3D,\n", - " colors_precomp, jax.nn.sigmoid(opacity), jnp.exp(scales), rotations, camera_pose, intrinsics\n", + " colors_precomp,\n", + " jax.nn.sigmoid(opacity),\n", + " jnp.exp(scales),\n", + " rotations,\n", + " camera_pose,\n", + " intrinsics,\n", ")\n", - "b.hstack_images([\n", - " b.get_rgb_image(jnp.transpose(color, (1,2,0))[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [b.get_rgb_image(jnp.transpose(color, (1, 2, 0))[..., :3] * 255.0), gt_viz]\n", + ")" ] }, { @@ -319,56 +399,70 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "rasterizer = bgs.intrinsics_to_rasterizer(intrinsics, jnp.eye(4))\n", - "means3D = torch.tensor(torch.rand((10,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means3D = torch.tensor(\n", + " torch.rand((10, 3)) - 0.5 + torch.tensor([0.0, 0.0, 1.0]),\n", + " requires_grad=True,\n", + " device=device,\n", + ")\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp_gt = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device).detach()\n", - "scales = torch.tensor(-8.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp_gt = torch.tensor(\n", + " torch.rand((N, 3)), requires_grad=True, device=device\n", + ").detach()\n", + "scales = torch.tensor(-8.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", "\n", - "color,radii = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp_gt,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "color, radii = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp_gt,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", "print(colors_precomp_gt)\n", "color_gt = color.detach()\n", - "gt_viz = b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 
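+    "# Sanity check: the target image comes from 10 random gaussians, after which a\n",
+    "# fresh set of 20 gaussians is optimized below to reproduce it.\n",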
255.0)\n", + "gt_viz = b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0)\n", "gt_viz\n", "\n", - "means3D = torch.tensor(torch.rand((20,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means3D = torch.tensor(\n", + " torch.rand((20, 3)) - 0.5 + torch.tensor([0.0, 0.0, 1.0]),\n", + " requires_grad=True,\n", + " device=device,\n", + ")\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-8.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-8.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", "\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -396,31 +490,39 @@ } ], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [means3D], 'lr': 0.001, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [means3D], \"lr\": 0.001, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", - " loss = torch.abs((color_gt - color)[:3,...]).mean()\n", + " loss = torch.abs((color_gt - color)[:3, ...]).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, 
(1,2,0)).cpu().detach().numpy()[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(\n", + " torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + " ),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -464,36 +566,54 @@ "source": [ "T = 0\n", "rasterizer = intrinsics_to_rasterizer(intrinsics, poses[T])\n", - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[T],intrinsics.height, intrinsics.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(b.utils.resize(colors[T], intrinsics.height, intrinsics.width))[\n", + " ..., :3\n", + " ]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "gt_viz\n", "\n", "# means3D a= torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = np.array(b.apply_transform(b.unproject_depth_jit(depths[T], intrinsics_depth), poses[T]).reshape(-1,3))\n", + "cloud = np.array(\n", + " b.apply_transform(\n", + " b.unproject_depth_jit(depths[T], intrinsics_depth), poses[T]\n", + " ).reshape(-1, 3)\n", + ")\n", "# cloud = cloud[cloud[:,2] < 1.0,:]\n", "# choice = jax.random.choice(jax.random.PRNGKey(1000), cloud.shape[0], shape=(12000,), replace=False)\n", "# cloud = cloud[choice]\n", "means3D = torch.tensor(cloud, requires_grad=True, device=device)\n", "\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-15.0 * torch.ones((N,1)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-15.0 * torch.ones((N, 1)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", "\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", ")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " ]\n", + ")" ] }, { @@ -530,38 +650,53 @@ "source": [ "T = 50\n", "rasterizer = intrinsics_to_rasterizer(intrinsics, poses[T])\n", - "color_gt = 
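+    "# Tracking step: the splats fitted at T=0 are re-optimized against frame 50;\n",
+    "# only colors_precomp, scales, and rotations receive gradients here (the means3D\n",
+    "# param group is commented out), so the geometry stays fixed.\n",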
(torch.tensor(np.array(b.utils.resize(colors[T],intrinsics.height, intrinsics.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(b.utils.resize(colors[T], intrinsics.height, intrinsics.width))[\n", + " ..., :3\n", + " ]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "gt_viz\n", "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(1))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -577,34 +712,40 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = 
torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -628,6 +769,7 @@ ], "source": [ "from bayes3d.viz.open3dviz import Open3DVisualizer\n", + "\n", "visualizer = Open3DVisualizer(intrinsics)" ] }, @@ -651,12 +793,18 @@ "source": [ "import os\n", "import open3d as o3d\n", + "\n", "model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "i = 4\n", "mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i], \"textured.obj\")\n", - "mesh = o3d.io.read_triangle_model(mesh_path)\n", + "mesh = o3d.io.read_triangle_model(mesh_path)\n", "visualizer.render.scene.clear_geometry()\n", - "pose = b.distributions.gaussian_vmf(jax.random.PRNGKey(1000),b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])), 0.01, 1.0)\n", + "pose = b.distributions.gaussian_vmf(\n", + " jax.random.PRNGKey(1000),\n", + " b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])),\n", + " 0.01,\n", + " 1.0,\n", + ")\n", "mesh.meshes[0].mesh.transform(pose)\n", "visualizer.render.scene.add_model(f\"m\", mesh)\n", "mesh.meshes[0].mesh.transform(jnp.linalg.inv(pose))\n", @@ -696,37 +844,44 @@ } ], "source": [ - "color_gt = (torch.tensor(np.array(rgbd.rgb)[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (torch.tensor(np.array(rgbd.rgb)[..., :3]).permute(2, 0, 1) / 255.0).cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "gt_viz\n", "\n", "# means3D = torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = np.array(b.unproject_depth_jit(rgbd.depth, rgbd.intrinsics).reshape(-1,3))\n", - "cloud = cloud[cloud[:,2] < intrinsics.far,:]\n", + "cloud = np.array(b.unproject_depth_jit(rgbd.depth, rgbd.intrinsics).reshape(-1, 3))\n", + "cloud = cloud[cloud[:, 2] < intrinsics.far, :]\n", "# choice = jax.random.choice(jax.random.PRNGKey(1000), cloud.shape[0], shape=(100,), replace=False)\n", "# cloud = cloud[choice]\n", "means3D = torch.tensor(cloud, requires_grad=True, device=device)\n", "\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-12.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, 
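+    "# Splat initialization: means3D comes from unprojecting the observed depth image,\n",
+    "# so the optimizer below only has to recover appearance (colors, scales,\n",
+    "# rotations) rather than geometry.\n",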
device=device)\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-12.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", "\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = torch.sigmoid(colors_precomp),\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=torch.sigmoid(colors_precomp),\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -754,33 +909,39 @@ } ], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 0.1, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 0.1, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = torch.sigmoid(colors_precomp),\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=torch.sigmoid(colors_precomp),\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -789,13 +950,15 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy(),\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.001)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.001,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { @@ -871,13 +1034,15 @@ "outputs": [], "source": [ "b.clear()\n", - 
"b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy() * 0.3,\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.002)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.002,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { @@ -919,48 +1084,70 @@ } ], "source": [ - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[20],intrinsics.height, intrinsics.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(b.utils.resize(colors[20], intrinsics.height, intrinsics.width))[\n", + " ..., :3\n", + " ]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "\n", "\n", "pos = torch.tensor([0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "\n", "\n", - "optimizer = torch.optim.Adam([\n", - " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [pos], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", + " {\"params\": [pos], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", + " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", " pose = pytorch3d.transforms.Transform3d(\n", - " matrix=torch.vstack([torch.hstack([pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3,1)]), torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)])\n", + " matrix=torch.vstack(\n", + " [\n", + " torch.hstack(\n", + " [pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3, 1)]\n", + " ),\n", + " torch.tensor([0.0, 0.0, 0.0, 1.0], device=device),\n", + " ]\n", + " )\n", " )\n", " pose.transform_points(means3D)\n", - " color,_ = rasterizer(\n", - " means3D = pose.transform_points(means3D),\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=pose.transform_points(means3D),\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + 
"b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -982,8 +1169,8 @@ ], "source": [ "b.overlay_image(\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", ")" ] }, @@ -1030,13 +1217,15 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy() * 0.3,\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.002)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.002,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { diff --git a/scripts/experiments/gaussian_splatting/3dgs_tracking.ipynb b/scripts/experiments/gaussian_splatting/3dgs_tracking.ipynb index 9239c836..42bc2db3 100644 --- a/scripts/experiments/gaussian_splatting/3dgs_tracking.ipynb +++ b/scripts/experiments/gaussian_splatting/3dgs_tracking.ipynb @@ -18,7 +18,10 @@ ], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -60,10 +63,10 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(5).rjust(6, '0') + \".ply\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(5).rjust(6, \"0\") + \".ply\")\n", "mesh = b.utils.load_mesh(mesh_path)\n", - "vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0),device=device)" + "vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0), device=device)" ] }, { @@ -82,11 +85,7 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2\n", "fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2\n", @@ -108,11 +107,16 @@ " P[1, 1] = 2.0 * intrinsics.near / (top - bottom)\n", " P[0, 2] = (right + left) / (right - left)\n", " P[1, 2] = (top + bottom) / (top - bottom)\n", - " P[2, 2] = z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", - " P[2, 3] = -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " P[2, 2] = (\n", + " z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", + " P[2, 3] = (\n", + " -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", " P[3, 2] = z_sign\n", " return torch.transpose(P, 0, 1)\n", "\n", + "\n", "proj_matrix = torch.tensor(getProjectionMatrix(intrinsics), device=device)" ] }, @@ -137,23 +141,33 @@ "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), 
position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", + "\n", + "\n", "def apply_transform(points, transform):\n", " rels_ = torch.cat(\n", " (\n", " points,\n", - " torch.ones((points.shape[0], 1), device=device),\n", + " torch.ones((points.shape[0], 1), device=device),\n", " ),\n", " 1,\n", " )\n", - " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[...,:3]\n", + " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[..., :3]\n", + "\n", + "\n", "position = torch.tensor([0.0, 0.1, 0.2], device=device)\n", - "quat = torch.tensor([1.0, 0.1, 0.2, 0.3],device=device)\n", - "points = torch.zeros((5,3), device = device)\n", + "quat = torch.tensor([1.0, 0.1, 0.2, 0.3], device=device)\n", + "points = torch.zeros((5, 3), device=device)\n", "print(apply_transform(points, posevec_to_matrix(position, quat)))" ] }, @@ -164,7 +178,9 @@ "outputs": [], "source": [ "camera_pose = jnp.eye(4)\n", - "view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose))),0,1).cuda()\n", + "view_matrix = torch.transpose(\n", + " torch.tensor(np.array(b.inverse_pose(camera_pose))), 0, 1\n", + ").cuda()\n", "raster_settings = GaussianRasterizationSettings(\n", " image_height=int(intrinsics.height),\n", " image_width=int(intrinsics.width),\n", @@ -177,9 +193,9 @@ " sh_degree=1,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", ")\n", - "rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n" + "rasterizer = GaussianRasterizer(raster_settings=raster_settings)" ] }, { @@ -195,22 +211,22 @@ "metadata": {}, "outputs": [], "source": [ - "def render(pos,quat):\n", + "def render(pos, quat):\n", " means3D = apply_transform(vertices, posevec_to_matrix(pos, quat))\n", " N = means3D.shape[0]\n", - " means2D = torch.ones((N, 3),requires_grad=True, device=device)\n", - " opacity = torch.ones((N, 1),requires_grad=True,device=device)\n", - " scales = torch.tensor( 0.005 * torch.rand((N, 3)),requires_grad=True,device=device)\n", - " rotations = torch.rand((N, 4),requires_grad=True,device=device)\n", + " means2D = torch.ones((N, 3), requires_grad=True, device=device)\n", + " opacity = torch.ones((N, 1), requires_grad=True, device=device)\n", + " scales = torch.tensor(0.005 * torch.rand((N, 3)), requires_grad=True, device=device)\n", + " rotations = torch.rand((N, 4), requires_grad=True, device=device)\n", "\n", - " color, radii= rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = means3D[:,2:3].repeat(1,3),\n", - " opacities = opacity,\n", - " scales = scales,\n", - " rotations = rotations\n", + " color, radii = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=means3D[:, 2:3].repeat(1, 3),\n", + " opacities=opacity,\n", + " scales=scales,\n", + " rotations=rotations,\n", " )\n", " return color" ] @@ -237,7 +253,7 @@ ], "source": [ "b.setup_renderer(intrinsics)\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n" + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)" ] }, { @@ -247,17 +263,24 @@ "outputs": [], "source": [ "quat = jax.random.uniform(jax.random.PRNGKey(30), (4,))\n", - "poses = 
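`apply_transform` above is the usual homogeneous transform written row-wise: append a 1 to every point, contract with the 4x4 matrix, drop the last coordinate. A quick check of the einsum against plain matrix algebra (pure PyTorch, no GPU required):

import torch

T = torch.eye(4)
T[:3, 3] = torch.tensor([0.0, 0.1, 0.2])          # pure translation
points = torch.rand(5, 3)
hom = torch.cat([points, torch.ones(5, 1)], dim=1)
out = torch.einsum("ij, aj -> ai", T, hom)[..., :3]
assert torch.allclose(out, points + T[:3, 3])     # einsum == (T @ [p; 1]) per row

Note also that the `render(pos, quat)` defined above feeds `means3D[:, 2:3].repeat(1, 3)` in as the precomputed color, so the "RGB" output of the rasterizer is really an alpha-blended depth map.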
[b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), jnp.array([0.0, 0.0, 0.5]))]\n", + "poses = [\n", + " b.transform_from_rot_and_pos(\n", + " b.quaternion_to_rotation_matrix(quat), jnp.array([0.0, 0.0, 0.5])\n", + " )\n", + "]\n", "delta_pose = b.t3d.transform_from_rot_and_pos(\n", - " R.from_euler('zyx', [2.0, 1.1, 3.0], degrees=True).as_matrix(),\n", - " jnp.array([-0.0009, 0.0005, 0.0002])\n", + " R.from_euler(\"zyx\", [2.0, 1.1, 3.0], degrees=True).as_matrix(),\n", + " jnp.array([-0.0009, 0.0005, 0.0002]),\n", ")\n", "num_frames = 200\n", - "for t in range(num_frames-1):\n", + "for t in range(num_frames - 1):\n", " poses.append(poses[-1].dot(delta_pose))\n", "poses = jnp.stack(poses)\n", - "gt_images = b.RENDERER.render_many(poses[:,None,...], jnp.array([0]))\n", - "viz_gt_images = [b.get_depth_image(gt_images[i,...,2],max=intrinsics.far) for i in range(num_frames)]" + "gt_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))\n", + "viz_gt_images = [\n", + " b.get_depth_image(gt_images[i, ..., 2], max=intrinsics.far)\n", + " for i in range(num_frames)\n", + "]" ] }, { @@ -290,31 +313,39 @@ } ], "source": [ - "pos = torch.tensor(np.array(poses[0][:3,3]),device=device, requires_grad=True)\n", - "quat = torch.tensor(np.array(b.rotation_matrix_to_quaternion(poses[0][:3,:3])),device=device, requires_grad=True)\n", + "pos = torch.tensor(np.array(poses[0][:3, 3]), device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " np.array(b.rotation_matrix_to_quaternion(poses[0][:3, :3])),\n", + " device=device,\n", + " requires_grad=True,\n", + ")\n", "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.025, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.025, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "viz_images_inferred = []\n", "losses_over_time = []\n", "pbar = tqdm(range(len(gt_images)))\n", "for timestep in pbar:\n", - " gt_image = torch.tensor(np.array(gt_images[timestep][...,2]),device=device)\n", + " gt_image = torch.tensor(np.array(gt_images[timestep][..., 2]), device=device)\n", " for _ in range(15):\n", - " rendered_image = render(pos, quat)\n", + " rendered_image = render(pos, quat)\n", " loss = torch.abs(gt_image - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - " depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - " viz = b.get_depth_image(depth_image,max=intrinsics.far)\n", + " depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + " viz = b.get_depth_image(depth_image, max=intrinsics.far)\n", " viz_images_inferred.append(viz)\n", - " losses_over_time.append(loss.item())\n" + " losses_over_time.append(loss.item())" ] }, { @@ -324,8 +355,14 @@ "outputs": [], "source": [ "b.make_gif_from_pil_images(\n", - " [b.multi_panel([b.viz.scale_image(c,2.0),b.viz.scale_image(d,2.0)],labels=[\"Original\", \"3DGS Tracking\"]) for (c,d) in zip(viz_gt_images, viz_images_inferred)],\n", - " \"test_inferred.gif\"\n", + " [\n", + " b.multi_panel(\n", + " [b.viz.scale_image(c, 2.0), b.viz.scale_image(d, 2.0)],\n", + " labels=[\"Original\", \"3DGS Tracking\"],\n", + " )\n", + " for (c, d) in zip(viz_gt_images, 
viz_images_inferred)\n", + " ],\n", + " \"test_inferred.gif\",\n", ")" ] }, @@ -336,9 +373,13 @@ "outputs": [], "source": [ "timestep = 0\n", - "pos = torch.tensor(np.array(poses[0][:3,3]),device=device, requires_grad=True)\n", - "quat = torch.tensor(np.array(b.rotation_matrix_to_quaternion(poses[0][:3,:3])),device=device, requires_grad=True)\n", - "rendered_image = render(pos, quat)\n" + "pos = torch.tensor(np.array(poses[0][:3, 3]), device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " np.array(b.rotation_matrix_to_quaternion(poses[0][:3, :3])),\n", + " device=device,\n", + " requires_grad=True,\n", + ")\n", + "rendered_image = render(pos, quat)" ] }, { @@ -359,7 +400,7 @@ } ], "source": [ - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", "viz = b.get_depth_image(depth_image)\n", "viz" ] @@ -384,12 +425,16 @@ } ], "source": [ - "pos = torch.tensor(states[0][0],device=device, requires_grad=True)\n", - "quat = torch.tensor(states[0][1],device=device, requires_grad=True)\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.05, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "pos = torch.tensor(states[0][0], device=device, requires_grad=True)\n", + "quat = torch.tensor(states[0][1], device=device, requires_grad=True)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.05, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "viz_images_inferred = []\n", @@ -397,20 +442,23 @@ "pbar = tqdm(range(len(images)))\n", "for timestep in pbar:\n", " for _ in range(5):\n", - " rendered_image = render(pos, quat)\n", + " rendered_image = render(pos, quat)\n", " loss = torch.abs(images[timestep] - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", + " depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", " viz = b.get_depth_image(depth_image)\n", " viz_images_inferred.append(viz)\n", " losses_over_time.append(loss.item())\n", " pbar.set_description(f\"{loss.item()}\")\n", "\n", "b.make_gif_from_pil_images(\n", - " [b.multi_panel([c,d],labels=[\"Original\", \"Tracked\"]) for (c,d) in zip(viz_images, viz_images_inferred)],\n", - " \"test_inferred.gif\"\n", + " [\n", + " b.multi_panel([c, d], labels=[\"Original\", \"Tracked\"])\n", + " for (c, d) in zip(viz_images, viz_images_inferred)\n", + " ],\n", + " \"test_inferred.gif\",\n", ")" ] }, @@ -451,17 +499,17 @@ } ], "source": [ - "\n", - "\n", - "pos = torch.tensor([0.0, 0.0, 0.5],device=device, requires_grad=True)\n", - "quat = torch.tensor(quat + torch.rand(4,device=device)*0.4,device=device, requires_grad=True)\n", - "rendered_image = render(pos, quat)\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "pos = torch.tensor([0.0, 0.0, 0.5], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " quat + torch.rand(4, device=device) * 0.4, device=device, requires_grad=True\n", + ")\n", + "rendered_image = render(pos, quat)\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 
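Both tracking loops above share one structure: the pose parameters are never re-initialized between frames, so each frame's optimization warm-starts from the previous frame's estimate, and a handful of inner Adam steps (15 and 5 respectively) suffice to re-lock onto the slowly moving target. The skeleton of that pattern, reduced to a scalar toy problem (everything here is synthetic):

import torch

x = torch.zeros(1, requires_grad=True)        # stands in for (pos, quat)
opt = torch.optim.Adam([x], lr=0.1)
trajectory = torch.linspace(0.0, 2.0, 50)     # slowly moving ground truth
estimates = []
for target in trajectory:
    for _ in range(15):                       # inner refinement steps
        loss = (x - target).abs().mean()      # stands in for the image L1
        opt.zero_grad()
        loss.backward()
        opt.step()
    estimates.append(x.item())                # x tracks `target` frame by frame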
2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = b.get_depth_image(depth_image)\n", "parameters_over_time = []\n", "losses_over_time = []\n", - "b.hstack_images([viz, viz_gt])\n" + "b.hstack_images([viz, viz_gt])" ] }, { @@ -491,29 +539,31 @@ } ], "source": [ - "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.1, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.1, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " rendered_image = render(pos, quat)\n", + " rendered_image = render(pos, quat)\n", " loss = torch.abs(gt_rendered_image - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " parameters_over_time.append((pos.detach().clone(),quat.detach().clone()))\n", + " parameters_over_time.append((pos.detach().clone(), quat.detach().clone()))\n", " losses_over_time.append(loss.item())\n", " pbar.set_description(f\"{loss.item()}\")\n", "\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = b.get_depth_image(depth_image)\n", - "b.hstack_images([viz, viz_gt])\n", - "\n" + "b.hstack_images([viz, viz_gt])" ] }, { @@ -551,31 +601,31 @@ ], "source": [ "T = 0\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(2,2,1)\n", + "fig = plt.figure(figsize=(6, 6))\n", + "ax = fig.add_subplot(2, 2, 1)\n", "ax.set_title(\"Target\")\n", - "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img1 = ax.imshow(b.preprocess_for_viz(gt_image),cmap=b.cmap)\n", - "ax = fig.add_subplot(2,2,2)\n", + "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img1 = ax.imshow(b.preprocess_for_viz(gt_image), cmap=b.cmap)\n", + "ax = fig.add_subplot(2, 2, 2)\n", "parameters = parameters_over_time[T]\n", "rendered_image = render(*parameters)\n", - "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img2 = ax.imshow(b.preprocess_for_viz(rendered_image),cmap=b.cmap)\n", + "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img2 = ax.imshow(b.preprocess_for_viz(rendered_image), cmap=b.cmap)\n", "title = ax.set_title(f\"Reconstruction\")\n", - "ax = fig.add_subplot(2,1,2)\n", + "ax = fig.add_subplot(2, 1, 2)\n", "line = ax.plot(jnp.arange(T), losses_over_time[:T])\n", "# ax.set_yscale(\"log\")\n", "ax.set_title(\"Pixelwise MSE Loss\")\n", "ax.set_ylim(-0.0001, 0.1)\n", "ax.set_xlabel(\"Iteration\")\n", - "ax.set_xlim(0,len(losses_over_time))\n", + "ax.set_xlim(0, len(losses_over_time))\n", "fig.tight_layout()\n", "\n", "buffs = []\n", - "for T in tqdm(range(0,len(losses_over_time),5)):\n", + "for T in tqdm(range(0, len(losses_over_time), 5)):\n", " parameters = parameters_over_time[T]\n", " rendered_image = render(*parameters)\n", - " rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", + " rendered_image = 
np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", " img2.set_array(b.preprocess_for_viz(rendered_image))\n", " line[0].set_xdata(jnp.arange(T))\n", " line[0].set_ydata(losses_over_time[:T])\n", diff --git a/scripts/experiments/gaussian_splatting/3dgs_validate_optimizing.ipynb b/scripts/experiments/gaussian_splatting/3dgs_validate_optimizing.ipynb index db84802b..5c56d8a7 100644 --- a/scripts/experiments/gaussian_splatting/3dgs_validate_optimizing.ipynb +++ b/scripts/experiments/gaussian_splatting/3dgs_validate_optimizing.ipynb @@ -18,7 +18,10 @@ ], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -65,41 +68,47 @@ "\n", "\n", "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "\n", + "\n", "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", + "\n", + "\n", "def apply_transform(points, transform):\n", " rels_ = torch.cat(\n", " (\n", " points,\n", - " torch.ones((points.shape[0], 1), device=device),\n", + " torch.ones((points.shape[0], 1), device=device),\n", " ),\n", " 1,\n", " )\n", - " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[...,:3]\n", + " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[..., :3]\n", "\n", "\n", "def intrinsics_to_rasterizer(intrinsics, camera_pose_jax):\n", - "\n", " fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0\n", " fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2.0\n", " tan_fovx = math.tan(fovX)\n", " tan_fovy = math.tan(fovY)\n", "\n", - " proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda()\n", - " view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose_jax))),0,1).cuda()\n", + " proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0, 1).cuda()\n", + " view_matrix = torch.transpose(\n", + " torch.tensor(np.array(b.inverse_pose(camera_pose_jax))), 0, 1\n", + " ).cuda()\n", "\n", " raster_settings = GaussianRasterizationSettings(\n", " image_height=int(intrinsics.height),\n", @@ -113,7 +122,7 @@ " sh_degree=0,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", " )\n", " rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n", " return rasterizer" @@ -198,56 +207,74 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "rasterizer = intrinsics_to_rasterizer(intrinsics, jnp.eye(4))\n", - "means3D = torch.tensor(torch.rand((10,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), 
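The figure loop above uses matplotlib's update-in-place idiom: create the `imshow` and line artists once, then mutate them per frame with `set_array` and `set_xdata`/`set_ydata`, which is much cheaper than rebuilding the figure for every GIF frame. (One small wrinkle: the loss being plotted is the mean absolute error from `torch.abs(...).mean()`, even though the subplot is titled "Pixelwise MSE Loss".) A self-contained version of the pattern with synthetic data:

import numpy as np
import matplotlib.pyplot as plt

losses = np.exp(-np.arange(100) / 20.0)       # stand-in loss curve
fig, (ax_img, ax_loss) = plt.subplots(2, 1)
img = ax_img.imshow(np.zeros((8, 8)), vmin=0.0, vmax=1.0)
(line,) = ax_loss.plot([], [])
ax_loss.set_xlim(0, len(losses))
ax_loss.set_ylim(0.0, 1.0)
for T in range(0, len(losses), 5):
    img.set_array(np.random.rand(8, 8))       # stands in for the re-render
    line.set_data(np.arange(T), losses[:T])
    fig.canvas.draw()                         # then grab the buffer for the GIF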
requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means3D = torch.tensor(\n", + " torch.rand((10, 3)) - 0.5 + torch.tensor([0.0, 0.0, 1.0]),\n", + " requires_grad=True,\n", + " device=device,\n", + ")\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp_gt = torch.tensor(torch.rand((N,5)), requires_grad=True, device=device).detach()\n", - "scales = torch.tensor(-8.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", - "\n", - "color,radii = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp_gt,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp_gt = torch.tensor(\n", + " torch.rand((N, 5)), requires_grad=True, device=device\n", + ").detach()\n", + "scales = torch.tensor(-8.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", + "\n", + "color, radii = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp_gt,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", "print(colors_precomp_gt)\n", "color_gt = color.detach()\n", - "gt_viz = b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + ")\n", "gt_viz\n", "\n", - "means3D = torch.tensor(torch.rand((20,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means3D = torch.tensor(\n", + " torch.rand((20, 3)) - 0.5 + torch.tensor([0.0, 0.0, 1.0]),\n", + " requires_grad=True,\n", + " device=device,\n", + ")\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,5)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-8.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 5)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-8.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), 
requires_grad=True, device=device)\n", + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", "\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(\n", + " torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + " ),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -275,31 +302,39 @@ } ], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [means3D], 'lr': 0.001, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [means3D], \"lr\": 0.001, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(\n", + " torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + " ),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -309,7 +344,7 @@ "outputs": [], "source": [ "r3d_path = \"/home/nishadgothoskar/bayes3d/scripts/2023-09-05--16-08-21.r3d.zip\"\n", - "colors, depths, poses, intrinsics, intrinsics_depth = b.utils.load_r3d(r3d_path);" + "colors, depths, poses, intrinsics, intrinsics_depth = b.utils.load_r3d(r3d_path)" ] }, { @@ -340,22 +375,44 @@ } ], "source": [ - "T =0\n", + "T = 0\n", "rasterizer = intrinsics_to_rasterizer(intrinsics_depth, jnp.eye(4))\n", - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[T],intrinsics_depth.height, intrinsics_depth.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "depth_gt = (torch.tensor(np.array(b.utils.resize(depths[T],intrinsics_depth.height, intrinsics_depth.width)))).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(colors[T], intrinsics_depth.height, intrinsics_depth.width)\n", + " )[..., :3]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "depth_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(depths[T], intrinsics_depth.height, intrinsics_depth.width)\n", + " )\n", + " )\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + ")\n", "gt_viz\n", "\n", "# means3D 
a= torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = np.array(b.apply_transform(b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)).reshape(-1,3))\n", + "cloud = np.array(\n", + " b.apply_transform(\n", + " b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)\n", + " ).reshape(-1, 3)\n", + ")\n", "means3D = torch.tensor(cloud, requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-15.0 * torch.ones((N,1)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)" + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-15.0 * torch.ones((N, 1)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)" ] }, { @@ -390,43 +447,68 @@ } ], "source": [ - "T =0\n", + "T = 0\n", "rasterizer = intrinsics_to_rasterizer(intrinsics_depth, jnp.eye(4))\n", - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[T],intrinsics_depth.height, intrinsics_depth.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "depth_gt = (torch.tensor(np.array(b.utils.resize(depths[T],intrinsics_depth.height, intrinsics_depth.width)))).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(colors[T], intrinsics_depth.height, intrinsics_depth.width)\n", + " )[..., :3]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "depth_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(depths[T], intrinsics_depth.height, intrinsics_depth.width)\n", + " )\n", + " )\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + ")\n", "gt_viz\n", "\n", "# means3D a= torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = np.array(b.apply_transform(b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)).reshape(-1,3))\n", + "cloud = np.array(\n", + " b.apply_transform(\n", + " b.unproject_depth_jit(depths[T], intrinsics_depth), jnp.eye(4)\n", + " ).reshape(-1, 3)\n", + ")\n", "# cloud = cloud[cloud[:,2] < 1.0,:]\n", "# choice = jax.random.choice(jax.random.PRNGKey(1000), cloud.shape[0], shape=(4000,), replace=False)\n", "# cloud = cloud[choice]\n", "means3D = torch.tensor(cloud, requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = 
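`expanded_color` above packs five values per Gaussian: RGB, the point's camera-space depth, and a constant 1. The rasterizer build used in these notebooks evidently blends however many precomputed channels it is given (it is handed five-channel tensors throughout), so one pass produces an RGB image, an alpha-blended depth map in `color[3]`, and a soft coverage mask in `color[4]`. The assembly step in isolation (shapes illustrative):

import torch

N = 7
colors_precomp = torch.rand(N, 3)
means3D = torch.rand(N, 3) + torch.tensor([0.0, 0.0, 1.0])
expanded_color = torch.cat(
    [
        colors_precomp,        # channels 0-2: RGB
        means3D[:, 2:3],       # channel 3: per-Gaussian depth
        torch.ones(N, 1),      # channel 4: blends to per-pixel coverage
    ],
    dim=-1,
)                              # shape (N, 5)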
torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-15.0 * torch.ones((N,1)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", - "\n", - "expanded_color = torch.concatenate([\n", - " colors_precomp,\n", - " means3D[:,2:3],\n", - " torch.ones((means3D.shape[0],1), requires_grad=True, device=device)\n", - "],axis=-1)\n", - "\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = expanded_color,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-15.0 * torch.ones((N, 1)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", + "\n", + "expanded_color = torch.concatenate(\n", + " [\n", + " colors_precomp,\n", + " means3D[:, 2:3],\n", + " torch.ones((means3D.shape[0], 1), requires_grad=True, device=device),\n", + " ],\n", + " axis=-1,\n", ")\n", "\n", - "b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)" + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=expanded_color,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", + ")\n", + "\n", + "b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0)" ] }, { @@ -454,41 +536,54 @@ } ], "source": [ - "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " expanded_color = torch.concatenate([\n", - " colors_precomp,\n", - " means3D[:,2:3],\n", - " torch.ones((means3D.shape[0],1), requires_grad=True, device=device)\n", - " ],axis=-1)\n", - "\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = expanded_color,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + " expanded_color = torch.concatenate(\n", + " [\n", + " colors_precomp,\n", + " means3D[:, 2:3],\n", + " torch.ones((means3D.shape[0], 1), requires_grad=True, device=device),\n", + " ],\n", + " axis=-1,\n", + " )\n", + "\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=expanded_color,\n", + " 
opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", - " loss = torch.abs(color_gt - color[:3,...]).mean() + torch.abs(depth_gt - color[3,...]).mean() \n", + " loss = (\n", + " torch.abs(color_gt - color[:3, ...]).mean()\n", + " + torch.abs(depth_gt - color[3, ...]).mean()\n", + " )\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(\n", + " torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0\n", + " ),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -499,10 +594,10 @@ "source": [ "b.clear_visualizer()\n", "depth = jnp.array(color[3, ...].cpu().detach().numpy())\n", - "cloud = b.unproject_depth_jit(depth, intrinsics_depth).reshape(-1,3)\n", + "cloud = b.unproject_depth_jit(depth, intrinsics_depth).reshape(-1, 3)\n", "b.show_cloud(\"1\", cloud * 4.0)\n", "\n", - "cloud = b.unproject_depth_jit(depths[T], intrinsics_depth).reshape(-1,3)\n", + "cloud = b.unproject_depth_jit(depths[T], intrinsics_depth).reshape(-1, 3)\n", "# b.show_cloud(\"2\", cloud * 4.0, color=b.RED)" ] }, @@ -559,17 +654,20 @@ "source": [ "rasterizer = intrinsics_to_rasterizer(intrinsics_depth, jnp.eye(4))\n", "\n", - "T =0\n", + "T = 0\n", "# means3D a= torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = b.unproject_depth_jit(depths[T], intrinsics_depth).reshape(-1,3)\n", - "cloud_transformed = b.apply_transform(cloud, poses[T]) \n", + "cloud = b.unproject_depth_jit(depths[T], intrinsics_depth).reshape(-1, 3)\n", + "cloud_transformed = b.apply_transform(cloud, poses[T])\n", "means3D = torch.tensor(np.array(cloud_transformed), requires_grad=True, device=device)\n", - "means2D = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-15.0 * torch.ones((N,1)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-15.0 * torch.ones((N, 1)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", "\n", "\n", "T = 30\n", @@ -577,22 +675,25 @@ " matrix=torch.tensor(np.array(poses[T]).transpose(), device=device)\n", ")\n", "\n", - "expanded_color = torch.concatenate([\n", - " colors_precomp,\n", - " means3D[:,2:3],\n", - " torch.ones((means3D.shape[0],1), requires_grad=True, device=device)\n", - "],axis=-1)\n", - "\n", - "color,_ = rasterizer(\n", - " means3D = pose.inverse().transform_points(means3D),\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = expanded_color,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = 
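The loss above restricts the combined photometric-plus-depth error to pixels the splats cover confidently, using the coverage channel (`color[4, ...] > 0.99`) as a mask. As the cell is written, though, the following `loss = torch.abs(color_gt - color[:3, ...]).mean()` immediately overwrites the masked term, so only the plain RGB L1 is actually optimized. A sketch of the combined objective as a function (`masked_rgbd_loss` and its defaults are illustrative, taken from the overwritten line):

import torch

def masked_rgbd_loss(color, color_gt, depth_gt, thresh=0.99, rgb_weight=0.5):
    """color: (5, H, W) raster output; channels 3 and 4 are depth and coverage."""
    mask = color[4, ...] > thresh
    rgb_err = torch.abs(color_gt - color[:3, ...]).mean(0)  # (H, W)
    depth_err = torch.abs(depth_gt - color[3, ...])         # (H, W)
    return (rgb_weight * rgb_err + depth_err)[mask].mean()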
torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + "expanded_color = torch.concatenate(\n", + " [\n", + " colors_precomp,\n", + " means3D[:, 2:3],\n", + " torch.ones((means3D.shape[0], 1), requires_grad=True, device=device),\n", + " ],\n", + " axis=-1,\n", + ")\n", + "\n", + "color, _ = rasterizer(\n", + " means3D=pose.inverse().transform_points(means3D),\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=expanded_color,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", ")\n", - "b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)" + "b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0)" ] }, { @@ -614,45 +715,72 @@ } ], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [means3D], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 0.1, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", - "\n", - "for T in range(0,30,5):\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [means3D], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 0.1, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", + "\n", + "for T in range(0, 30, 5):\n", " for _ in range(30):\n", - " color_gt = (torch.tensor(np.array(b.utils.resize(colors[T],intrinsics_depth.height, intrinsics_depth.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - " depth_gt = (torch.tensor(np.array(b.utils.resize(depths[T],intrinsics_depth.height, intrinsics_depth.width)))).cuda()\n", + " color_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(\n", + " colors[T], intrinsics_depth.height, intrinsics_depth.width\n", + " )\n", + " )[..., :3]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + " ).cuda()\n", + " depth_gt = (\n", + " torch.tensor(\n", + " np.array(\n", + " b.utils.resize(\n", + " depths[T], intrinsics_depth.height, intrinsics_depth.width\n", + " )\n", + " )\n", + " )\n", + " ).cuda()\n", "\n", " pose = pytorch3d.transforms.Transform3d(\n", " matrix=torch.tensor(np.array(poses[T]).transpose(), device=device)\n", " )\n", "\n", - " newmeans3d = pose.inverse().transform_points(means3D)\n", + " newmeans3d = pose.inverse().transform_points(means3D)\n", "\n", - " expanded_color = torch.concatenate([\n", - " colors_precomp,\n", - " newmeans3d[:,2:3],\n", - " torch.ones((means3D.shape[0],1), requires_grad=True, device=device)\n", - " ],axis=-1)\n", - "\n", - " color,_ = rasterizer(\n", - " means3D = newmeans3d,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = expanded_color,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + " expanded_color = torch.concatenate(\n", + " [\n", + " colors_precomp,\n", + " newmeans3d[:, 2:3],\n", + " torch.ones((means3D.shape[0], 1), requires_grad=True, device=device),\n", + " ],\n", + " axis=-1,\n", + " )\n", + "\n", + " color, _ = rasterizer(\n", + " means3D=newmeans3d,\n", + " means2D=means2D,\n", + " shs=None,\n", + " 
colors_precomp=expanded_color,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", " )\n", "\n", " optimizer.zero_grad()\n", - " mask = (color[4,...] > 0.99)\n", - " loss = (0.5 * torch.abs(color_gt - color[:3,...]).mean(0) + torch.abs(depth_gt - color[3,...]))[mask].mean()\n", - " loss = torch.abs(color_gt - color[:3,...]).mean()\n", + " mask = color[4, ...] > 0.99\n", + " loss = (\n", + " 0.5 * torch.abs(color_gt - color[:3, ...]).mean(0)\n", + " + torch.abs(depth_gt - color[3, ...])\n", + " )[mask].mean()\n", + " loss = torch.abs(color_gt - color[:3, ...]).mean()\n", " loss.backward()\n", " optimizer.step()\n", " print(loss)" @@ -683,24 +811,27 @@ " matrix=torch.tensor(np.array(poses[T]).transpose(), device=device)\n", ")\n", "\n", - "newmeans3d = pose.inverse().transform_points(means3D)\n", - "\n", - "expanded_color = torch.concatenate([\n", - " colors_precomp,\n", - " newmeans3d[:,2:3],\n", - " torch.ones((means3D.shape[0],1), requires_grad=True, device=device)\n", - "],axis=-1)\n", - "\n", - "color,_ = rasterizer(\n", - " means3D = newmeans3d,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = expanded_color,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + "newmeans3d = pose.inverse().transform_points(means3D)\n", + "\n", + "expanded_color = torch.concatenate(\n", + " [\n", + " colors_precomp,\n", + " newmeans3d[:, 2:3],\n", + " torch.ones((means3D.shape[0], 1), requires_grad=True, device=device),\n", + " ],\n", + " axis=-1,\n", ")\n", - "b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy()[...,:3] * 255.0)" + "\n", + "color, _ = rasterizer(\n", + " means3D=newmeans3d,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=expanded_color,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", + ")\n", + "b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy()[..., :3] * 255.0)" ] }, { @@ -709,10 +840,11 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy(),\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.004\n", + " size=0.004,\n", ")" ] }, @@ -743,9 +875,9 @@ } ], "source": [ - "fig,ax = plt.subplots(2)\n", - "ax[0].matshow(color[3,...].cpu().detach().numpy())\n", - "ax[1].matshow(color[4,...].cpu().detach().numpy())" + "fig, ax = plt.subplots(2)\n", + "ax[0].matshow(color[3, ...].cpu().detach().numpy())\n", + "ax[1].matshow(color[4, ...].cpu().detach().numpy())" ] }, { @@ -782,9 +914,7 @@ } ], "source": [ - "plt.matshow(\n", - " color[4, ...].cpu().detach().numpy() \n", - ")\n", + "plt.matshow(color[4, ...].cpu().detach().numpy())\n", "plt.colorbar()" ] }, @@ -794,10 +924,11 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy(),\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.004\n", + " size=0.004,\n", ")" ] }, @@ -820,46 +951,69 @@ ], "source": [ "import pytorch3d\n", - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[20],intrinsics.height, intrinsics.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "\n", + 
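`pytorch3d.transforms.Transform3d` uses the row-vector convention (points multiply the matrix from the left), which is why the cells above transpose the 4x4 `poses[T]` before wrapping it, and why `pose.inverse().transform_points(...)` carries world-frame Gaussians into the camera frame. The convention in isolation, for a pure translation:

import torch
import pytorch3d.transforms as p3dt

cam_to_world = torch.eye(4)                       # column-convention 4x4 pose
cam_to_world[:3, 3] = torch.tensor([0.0, 0.0, 0.5])
pose = p3dt.Transform3d(matrix=cam_to_world.T)    # transpose for row vectors
pts_world = torch.rand(10, 3)
pts_cam = pose.inverse().transform_points(pts_world)
assert torch.allclose(pts_cam, pts_world - torch.tensor([0.0, 0.0, 0.5]))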
"color_gt = (\n", + " torch.tensor(\n", + " np.array(b.utils.resize(colors[20], intrinsics.height, intrinsics.width))[\n", + " ..., :3\n", + " ]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "\n", "pos = torch.tensor([0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "\n", - "optimizer = torch.optim.Adam([\n", - " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [pos], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", + " {\"params\": [pos], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", + " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", " pose = pytorch3d.transforms.Transform3d(\n", - " matrix=torch.vstack([torch.hstack([pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3,1)]), torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)])\n", + " matrix=torch.vstack(\n", + " [\n", + " torch.hstack(\n", + " [pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3, 1)]\n", + " ),\n", + " torch.tensor([0.0, 0.0, 0.0, 1.0], device=device),\n", + " ]\n", + " )\n", " )\n", " pose.transform_points(means3D)\n", - " color,_ = rasterizer(\n", - " means3D = pose.transform_points(means3D),\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=pose.transform_points(means3D),\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -896,34 +1050,40 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, 
\"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales).repeat((1,3)),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales).repeat((1, 3)),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -947,6 +1107,7 @@ ], "source": [ "from bayes3d.viz.open3dviz import Open3DVisualizer\n", + "\n", "visualizer = Open3DVisualizer(intrinsics)" ] }, @@ -970,12 +1131,18 @@ "source": [ "import os\n", "import open3d as o3d\n", + "\n", "model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "i = 4\n", "mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i], \"textured.obj\")\n", - "mesh = o3d.io.read_triangle_model(mesh_path)\n", + "mesh = o3d.io.read_triangle_model(mesh_path)\n", "visualizer.render.scene.clear_geometry()\n", - "pose = b.distributions.gaussian_vmf(jax.random.PRNGKey(1000),b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])), 0.01, 1.0)\n", + "pose = b.distributions.gaussian_vmf(\n", + " jax.random.PRNGKey(1000),\n", + " b.transform_from_pos(jnp.array([0.0, 0.0, 0.3])),\n", + " 0.01,\n", + " 1.0,\n", + ")\n", "mesh.meshes[0].mesh.transform(pose)\n", "visualizer.render.scene.add_model(f\"m\", mesh)\n", "mesh.meshes[0].mesh.transform(jnp.linalg.inv(pose))\n", @@ -1015,37 +1182,44 @@ } ], "source": [ - "color_gt = (torch.tensor(np.array(rgbd.rgb)[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (torch.tensor(np.array(rgbd.rgb)[..., :3]).permute(2, 0, 1) / 255.0).cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "gt_viz\n", "\n", "# means3D = torch.tensor(torch.rand((1000,3))-0.5 + torch.tensor([0.0, 0.0, 1.0]), requires_grad=True, device=device)\n", - "cloud = np.array(b.unproject_depth_jit(rgbd.depth, rgbd.intrinsics).reshape(-1,3))\n", - "cloud = cloud[cloud[:,2] < intrinsics.far,:]\n", + "cloud = np.array(b.unproject_depth_jit(rgbd.depth, rgbd.intrinsics).reshape(-1, 3))\n", + "cloud = cloud[cloud[:, 2] < intrinsics.far, :]\n", "# choice = jax.random.choice(jax.random.PRNGKey(1000), cloud.shape[0], shape=(100,), replace=False)\n", "# cloud = cloud[choice]\n", "means3D = torch.tensor(cloud, requires_grad=True, device=device)\n", "\n", - "means2D = torch.zeros_like(means3D, 
dtype=means3D.dtype, requires_grad=True, device=\"cuda\") + 0\n", + "means2D = (\n", + " torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device=\"cuda\")\n", + " + 0\n", + ")\n", "N = means3D.shape[0]\n", - "opacity = torch.tensor(torch.ones((N,1)), requires_grad=True, device=device)\n", - "colors_precomp = torch.tensor(torch.rand((N,3)), requires_grad=True, device=device)\n", - "scales = torch.tensor(-12.0 * torch.ones((N,3)), requires_grad=True, device=device)\n", - "rotations = torch.tensor(-10.0 * torch.ones((N,4)), requires_grad=True, device=device)\n", - "\n", - "color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = torch.sigmoid(colors_precomp),\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + "opacity = torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device)\n", + "colors_precomp = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "scales = torch.tensor(-12.0 * torch.ones((N, 3)), requires_grad=True, device=device)\n", + "rotations = torch.tensor(-10.0 * torch.ones((N, 4)), requires_grad=True, device=device)\n", + "\n", + "color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=torch.sigmoid(colors_precomp),\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -1073,33 +1247,39 @@ } ], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", - " {'params': [scales], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [rotations], 'lr': 0.1, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [colors_precomp], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [means3D], 'lr': 0.05, \"name\": \"pos\"},\n", + " {\"params\": [scales], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [rotations], \"lr\": 0.1, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " color,_ = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = torch.sigmoid(colors_precomp),\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=torch.sigmoid(colors_precomp),\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " 
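Initializing the Gaussian means from `b.unproject_depth_jit(rgbd.depth, rgbd.intrinsics)` gives the optimizer a geometrically plausible starting point rather than random positions. Assuming the usual pinhole model behind that helper (the `unproject` function below is an illustrative stand-in, using the same fx/fy/cx/cy fields as `b.Intrinsics`):

import numpy as np

def unproject(depth, fx, fy, cx, cy):
    """Back-project a depth image to camera-frame points (pinhole model)."""
    h, w = depth.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    x = (u - cx) / fx * depth
    y = (v - cy) / fy * depth
    return np.stack([x, y, depth], axis=-1)     # (H, W, 3)

cloud = unproject(np.full((200, 200), 0.5), 300.0, 300.0, 100.0, 100.0)
cloud = cloud.reshape(-1, 3)
cloud = cloud[cloud[:, 2] < 2.5]                # drop points beyond `far`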
gt_viz,\n", + " ]\n", + ")" ] }, { @@ -1108,13 +1288,15 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy(),\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.001)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.001,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { @@ -1191,13 +1373,15 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy() * 0.3,\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.002)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.002,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { @@ -1239,48 +1423,70 @@ } ], "source": [ - "color_gt = (torch.tensor(np.array(b.utils.resize(colors[20],intrinsics.height, intrinsics.width))[...,:3]).permute(2,0,1) / 255.0).cuda()\n", - "gt_viz = b.get_rgb_image(torch.permute(color_gt, (1,2,0)).cpu().detach().numpy() * 255.0)\n", + "color_gt = (\n", + " torch.tensor(\n", + " np.array(b.utils.resize(colors[20], intrinsics.height, intrinsics.width))[\n", + " ..., :3\n", + " ]\n", + " ).permute(2, 0, 1)\n", + " / 255.0\n", + ").cuda()\n", + "gt_viz = b.get_rgb_image(\n", + " torch.permute(color_gt, (1, 2, 0)).cpu().detach().numpy() * 255.0\n", + ")\n", "\n", "\n", "pos = torch.tensor([0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, requires_grad=True)\n", "\n", "\n", - "optimizer = torch.optim.Adam([\n", - " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [pos], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.01, \"name\": \"pos\"},\n", - " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", - " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", - " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " # {'params': [colors_precomp], 'lr': 0.01, \"name\": \"pos\"},\n", + " {\"params\": [pos], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.01, \"name\": \"pos\"},\n", + " # {'params': [scales], 'lr': 0.001, \"name\": \"pos\"},\n", + " # {'params': [rotations], 'lr': 1.0, \"name\": \"pos\"},\n", + " # {'params': [opacity], 'lr': 0.05, \"name\": \"pos\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", " pose = pytorch3d.transforms.Transform3d(\n", - " matrix=torch.vstack([torch.hstack([pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3,1)]), torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)])\n", + " matrix=torch.vstack(\n", + " [\n", + " torch.hstack(\n", + " [pytorch3d.transforms.quaternion_to_matrix(quat), pos.reshape(3, 1)]\n", + " ),\n", + " torch.tensor([0.0, 0.0, 0.0, 1.0], device=device),\n", + " ]\n", + " )\n", " )\n", " pose.transform_points(means3D)\n", - " color,_ = rasterizer(\n", - " means3D = pose.transform_points(means3D),\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = colors_precomp,\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " color, _ = 
rasterizer(\n", + " means3D=pose.transform_points(means3D),\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=colors_precomp,\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " optimizer.zero_grad()\n", " loss = torch.abs(color_gt - color).mean()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "b.hstack_images([\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", + " ]\n", + ")" ] }, { @@ -1302,8 +1508,8 @@ ], "source": [ "b.overlay_image(\n", - " b.get_rgb_image(torch.permute(color, (1,2,0)).cpu().detach().numpy() * 255.0),\n", - " gt_viz\n", + " b.get_rgb_image(torch.permute(color, (1, 2, 0)).cpu().detach().numpy() * 255.0),\n", + " gt_viz,\n", ")" ] }, @@ -1350,13 +1556,15 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"obs\", \n", + "b.show_cloud(\n", + " \"obs\",\n", " means3D.cpu().detach().numpy() * 0.3,\n", " color=jnp.transpose(torch.sigmoid(colors_precomp).cpu().detach().numpy()),\n", - " size=0.002)\n", - " # color=jnp.transpose(color.reshape(-1,3))/255.0, \n", + " size=0.002,\n", + ")\n", + "# color=jnp.transpose(color.reshape(-1,3))/255.0,\n", "\n", - " # size=0.003)" + "# size=0.003)" ] }, { diff --git a/scripts/experiments/gaussian_splatting/banana_tracking_with_spheres.ipynb b/scripts/experiments/gaussian_splatting/banana_tracking_with_spheres.ipynb index 2241656d..0df69818 100644 --- a/scripts/experiments/gaussian_splatting/banana_tracking_with_spheres.ipynb +++ b/scripts/experiments/gaussian_splatting/banana_tracking_with_spheres.ipynb @@ -29,8 +29,7 @@ "\n", "def open3d_mesh_to_trimesh(mesh):\n", " return trimesh.Trimesh(\n", - " vertices=np.asarray(mesh.vertices),\n", - " faces=np.asarray(mesh.triangles)\n", + " vertices=np.asarray(mesh.vertices), faces=np.asarray(mesh.triangles)\n", " )" ] }, @@ -42,12 +41,9 @@ "outputs": [], "source": [ "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", + "\n", "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.1, far=3.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.1, far=3.5\n", ")\n", "jax_renderer = JaxRenderer(intrinsics)" ] @@ -60,11 +56,12 @@ "outputs": [], "source": [ "import os\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 10\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "mesh = b.utils.scale_mesh(m, 1.0/100.0)" + "mesh = b.utils.scale_mesh(m, 1.0 / 100.0)" ] }, { @@ -75,13 +72,14 @@ "outputs": [], "source": [ "def xfm_points(points, matrix):\n", - " points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1)\n", + " points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)\n", " return jnp.matmul(points2, matrix.T)\n", "\n", + "\n", "# projection_matrix = b.camera._open_gl_projection_matrix(\n", - "# intrinsics.height, intrinsics.width, \n", - "# intrinsics.fx, 
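# Editor note (hedged): `b.camera._open_gl_projection_matrix` is used throughout this
# diff to turn pinhole intrinsics into a clip-space projection. The sketch below is
# the *standard* OpenGL-style construction for that mapping, written from the usual
# textbook formula -- it is an assumption that bayes3d's internal helper uses exactly
# these sign and row conventions.
import numpy as np

def opengl_projection(h, w, fx, fy, cx, cy, near, far):
    # Maps camera-space points to clip space; after the perspective divide, depth
    # lands in [-1, 1] and the camera looks down -z.
    return np.array([
        [2 * fx / w, 0.0,        1.0 - 2 * cx / w,             0.0],
        [0.0,        2 * fy / h, 2 * cy / h - 1.0,             0.0],
        [0.0,        0.0,        -(far + near) / (far - near), -2 * far * near / (far - near)],
        [0.0,        0.0,        -1.0,                         0.0],
    ])

print(opengl_projection(200, 200, 200.0, 200.0, 100.0, 100.0, 0.1, 3.5).round(3))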
intrinsics.fy, \n", - "# intrinsics.cx, intrinsics.cy, \n", + "# intrinsics.height, intrinsics.width,\n", + "# intrinsics.fx, intrinsics.fy,\n", + "# intrinsics.cx, intrinsics.cy,\n", "# intrinsics.near, intrinsics.far\n", "# )\n", "# self = jax_renderer\n", @@ -94,18 +92,27 @@ "# posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", "# pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", "\n", + "\n", "def render(vertices, faces, object_pose, intrinsics):\n", " projection_matrix = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", " )\n", " final_mtx_proj = projection_matrix @ object_pose\n", - " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", " pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", - " rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], faces, jnp.array([intrinsics.height, intrinsics.width]))\n", - " gb_pos,_ = jax_renderer.interpolate(posw[None,...], rast_out, faces, rast_out_db, jnp.array([0,1,2,3]))\n", + " rast_out, rast_out_db = jax_renderer.rasterize(\n", + " pos_clip_ja[None, ...], faces, jnp.array([intrinsics.height, intrinsics.width])\n", + " )\n", + " gb_pos, _ = jax_renderer.interpolate(\n", + " posw[None, ...], rast_out, faces, rast_out_db, jnp.array([0, 1, 2, 3])\n", + " )\n", " mask = rast_out[..., -1] > 0\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", @@ -113,9 +120,10 @@ " depth = xfm_points(gb_pos, object_pose)\n", " depth = depth.reshape(shape_keep)[..., 2] * -1\n", " return depth * mask, mask\n", - " \n", + "\n", + "\n", "jax.clear_caches()\n", - "render_jit = jax.jit(render)\n" + "render_jit = jax.jit(render)" ] }, { @@ -139,8 +147,10 @@ "source": [ "gt_position = jnp.array([0.0, 0.0, 2.8])\n", "gt_pose = b.transform_from_pos(gt_position)\n", - "gt_img,gt_mask = render_jit(mesh.vertices, mesh.faces, gt_pose, intrinsics)\n", - "b.hstack_images([b.get_depth_image(gt_img[0,...]),b.get_depth_image(gt_mask[0,...] * 1.0)])" + "gt_img, gt_mask = render_jit(mesh.vertices, mesh.faces, gt_pose, intrinsics)\n", + "b.hstack_images(\n", + " [b.get_depth_image(gt_img[0, ...]), b.get_depth_image(gt_mask[0, ...] 
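# Editor note (minimal sketch): `xfm_points` above appends a homogeneous 1 and
# right-multiplies by the transposed matrix; the rasterizer then consumes the raw
# clip-space coordinates. The perspective divide shown here is only to illustrate
# where normalized device coordinates would come from.
import jax.numpy as jnp

def xfm_points(points, matrix):
    points_h = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)
    return points_h @ matrix.T

M = jnp.eye(4).at[2, 3].set(2.8)    # pure z-translation, like the gt pose above
pts = jnp.zeros((4, 3))             # four points at the origin
clip = xfm_points(pts, M)           # shape (4, 4): x, y, z, w
ndc = clip[:, :3] / clip[:, 3:4]    # normalized device coordinates
print(clip.shape, ndc.shape)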
* 1.0)]\n", + ")" ] }, { @@ -151,8 +161,14 @@ "outputs": [], "source": [ "def loss(z, gt_img, gt_mask):\n", - " img,_ = render(mesh.vertices, mesh.faces, b.transform_from_pos(jnp.array([0.0, 0.0, z])), intrinsics)\n", - " return jnp.abs((gt_img - img)*gt_mask).mean()\n", + " img, _ = render(\n", + " mesh.vertices,\n", + " mesh.faces,\n", + " b.transform_from_pos(jnp.array([0.0, 0.0, z])),\n", + " intrinsics,\n", + " )\n", + " return jnp.abs((gt_img - img) * gt_mask).mean()\n", + "\n", "\n", "grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0,)))\n", "grad_func_no_jit = jax.value_and_grad(loss, argnums=(0,))" @@ -175,12 +191,12 @@ "source": [ "z = 2.6\n", "from tqdm import tqdm\n", + "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " loss,grad = grad_func(z, gt_img,gt_mask)\n", + " loss, grad = grad_func(z, gt_img, gt_mask)\n", " z = z - 0.1 * grad[0]\n", - " pbar.set_description(f\"{loss} z: {z}\")\n", - "\n" + " pbar.set_description(f\"{loss} z: {z}\")" ] }, { @@ -191,8 +207,9 @@ "outputs": [], "source": [ "def loss(z, gt_img, gt_mask):\n", - " img,_ = render(mesh.vertices, mesh.faces, b.transform_from_pos(z), intrinsics)\n", - " return jnp.abs((gt_img - img)*gt_mask).mean()\n", + " img, _ = render(mesh.vertices, mesh.faces, b.transform_from_pos(z), intrinsics)\n", + " return jnp.abs((gt_img - img) * gt_mask).mean()\n", + "\n", "\n", "grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0,)))\n", "grad_func_no_jit = jax.value_and_grad(loss, argnums=(0,))" @@ -207,17 +224,25 @@ "source": [ "quat = jax.random.uniform(jax.random.PRNGKey(30), (4,))\n", "translation = jnp.array([0.0, 0.0, 2.5])\n", - "poses = [b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), translation)]\n", + "poses = [\n", + " b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), translation)\n", + "]\n", "delta_pose = b.t3d.transform_from_rot_and_pos(\n", - " R.from_euler('zyx', [2.0, 1.1, 3.0], degrees=True).as_matrix(),\n", - " jnp.array([-0.009, 0.005, 0.02])\n", + " R.from_euler(\"zyx\", [2.0, 1.1, 3.0], degrees=True).as_matrix(),\n", + " jnp.array([-0.009, 0.005, 0.02]),\n", ")\n", "num_frames = 200\n", - "for t in range(num_frames-1):\n", + "for t in range(num_frames - 1):\n", " poses.append(poses[-1].dot(delta_pose))\n", "poses = jnp.stack(poses)\n", - "gt_images = [render_jit(mesh.vertices, mesh.faces, poses[i], intrinsics) for i in range(num_frames)]\n", - "viz_gt_images = [b.get_depth_image(gt_images[i][0][0,...],max=intrinsics.far) for i in range(num_frames)]\n", + "gt_images = [\n", + " render_jit(mesh.vertices, mesh.faces, poses[i], intrinsics)\n", + " for i in range(num_frames)\n", + "]\n", + "viz_gt_images = [\n", + " b.get_depth_image(gt_images[i][0][0, ...], max=intrinsics.far)\n", + " for i in range(num_frames)\n", + "]\n", "b.make_gif_from_pil_images(viz_gt_images, \"obs.gif\")" ] }, @@ -228,16 +253,29 @@ "metadata": {}, "outputs": [], "source": [ - "def render_translation_and_quat(translation,quat):\n", - " pose = b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), translation)\n", - " img,_ = render(mesh.vertices, mesh.faces, pose, intrinsics)\n", + "def render_translation_and_quat(translation, quat):\n", + " pose = b.transform_from_rot_and_pos(\n", + " b.quaternion_to_rotation_matrix(quat), translation\n", + " )\n", + " img, _ = render(mesh.vertices, mesh.faces, pose, intrinsics)\n", " return img\n", + "\n", + "\n", "def loss(translation, quat, gt_img, gt_mask):\n", " img = render_translation_and_quat(translation, quat)\n", - " return 
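# Editor note (self-contained sketch of the optimization pattern above):
# jax.value_and_grad returns the loss and its gradient in a single call, and jitting
# the closure is what makes the inner loop cheap. A toy 1-D version with the same
# structure, no renderer required:
import jax
import jax.numpy as jnp

def toy_loss(z, target):
    return jnp.abs(target - z**2).mean()    # stand-in for the masked depth L1

toy_grad = jax.jit(jax.value_and_grad(toy_loss, argnums=(0,)))

z, target = 2.6, jnp.array(4.84)            # the minimizer is z = 2.2
for _ in range(100):
    value, (g,) = toy_grad(z, target)
    z = z - 0.01 * g                        # plain gradient descent, as above
print(value, z)   # z ends in a small band around 2.2 (L1 steps have fixed size)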
jnp.abs((gt_img - img)*gt_mask).mean()\n", + " return jnp.abs((gt_img - img) * gt_mask).mean()\n", + "\n", "\n", "render_translation_and_quat_jit = jax.jit(render_translation_and_quat)\n", - "grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0,1,)))\n" + "grad_func = jax.jit(\n", + " jax.value_and_grad(\n", + " loss,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " ),\n", + " )\n", + ")" ] }, { @@ -260,7 +298,7 @@ } ], "source": [ - "grad_func(translation, quat, gt_images[0][0], gt_images[0][1])\n" + "grad_func(translation, quat, gt_images[0][0], gt_images[0][1])" ] }, { @@ -305,9 +343,13 @@ "pbar = tqdm(range(len(gt_images)))\n", "for timestep in pbar:\n", " for _ in range(15):\n", - " loss, (grad_translation, grad_quat, ) = grad_func(\n", - " translation, quat, gt_images[timestep][0], gt_images[timestep][1]\n", - " )\n", + " (\n", + " loss,\n", + " (\n", + " grad_translation,\n", + " grad_quat,\n", + " ),\n", + " ) = grad_func(translation, quat, gt_images[timestep][0], gt_images[timestep][1])\n", " translation = translation - 0.5 * grad_translation\n", " quat = quat - 1.0 * grad_quat\n", " pbar.set_description(f\"{loss}\")\n", @@ -322,7 +364,8 @@ "outputs": [], "source": [ "inferred_viz_images = [\n", - " b.get_depth_image(images_over_time[i][0,...],max=intrinsics.far) for i in range(num_frames)\n", + " b.get_depth_image(images_over_time[i][0, ...], max=intrinsics.far)\n", + " for i in range(num_frames)\n", "]\n", "all_imgs = [\n", " b.hstack_images([viz_gt_images[timestep], inferred_viz_images[timestep]])\n", @@ -389,8 +432,8 @@ ], "source": [ "pose = b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), translation)\n", - "img,_ = render(mesh.vertices, mesh.faces, pose, intrinsics)\n", - "b.get_depth_image(img[0,...],max=intrinsics.far)" + "img, _ = render(mesh.vertices, mesh.faces, pose, intrinsics)\n", + "b.get_depth_image(img[0, ...], max=intrinsics.far)" ] }, { @@ -421,7 +464,7 @@ "metadata": {}, "outputs": [], "source": [ - "z = jnp.array([0.0, 0.0,2.5])\n" + "z = jnp.array([0.0, 0.0, 2.5])" ] }, { @@ -440,9 +483,10 @@ ], "source": [ "from tqdm import tqdm\n", + "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " loss,grad = grad_func(z, gt_img,gt_mask)\n", + " loss, grad = grad_func(z, gt_img, gt_mask)\n", " z = z - 0.01 * grad[0]\n", " pbar.set_description(f\"{loss} z: {z}\")" ] @@ -512,10 +556,8 @@ "source": [ "pose = b.transform_from_pos(jnp.array([0.0, 0.0, z]))\n", "img = jax_renderer.render(mesh.vertices, mesh.faces, pose, intrinsics, None)\n", - "print(\n", - " jnp.abs(gt_img - img).mean()\n", - ")\n", - "b.get_depth_image(img[0,...])\n" + "print(jnp.abs(gt_img - img).mean())\n", + "b.get_depth_image(img[0, ...])" ] }, { @@ -624,11 +666,11 @@ "base_faces = jnp.array(mesh.faces)\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\",mesh.vertices)\n", + "b.show_cloud(\"1\", mesh.vertices)\n", "\n", "data = jnp.load(\"gaussians_banana.npz\")\n", "print(data.files)\n", - "means = data[\"mus\"]\n", + "means = data[\"mus\"]\n", "covs = data[\"choleskys\"]\n", "print(covs.shape)" ] @@ -651,16 +693,22 @@ "source": [ "def get_transformed_vertices(center, cov):\n", " return base_vertices @ (3.0 * cov.T) + center\n", - "all_vertices = jax.vmap(get_transformed_vertices, in_axes=(0,0))(means, covs)\n", + "\n", + "\n", + "all_vertices = jax.vmap(get_transformed_vertices, in_axes=(0, 0))(means, covs)\n", "print(all_vertices.shape)\n", - "all_faces = (jnp.arange(all_vertices.shape[0]) * base_vertices.shape[0])[:,None, None] + base_faces[None,...]\n", + "all_faces = 
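# Editor note (sketch of the face-instancing trick used just below): when N
# transformed copies of a base mesh are flattened into one vertex buffer, each copy's
# faces must be offset by copy_index * n_base_vertices so the indices address the
# right vertex slice. A tiny fully-worked example:
import jax.numpy as jnp

base_faces = jnp.array([[0, 1, 2]])     # one triangle over 3 base vertices
n_vertices, n_copies = 3, 4
offsets = (jnp.arange(n_copies) * n_vertices)[:, None, None]   # (4, 1, 1)
all_faces = offsets + base_faces[None, ...]                    # (4, 1, 3)
print(all_faces.reshape(-1, 3))
# [[0 1 2], [3 4 5], [6 7 8], [9 10 11]] -- each copy indexes its own vertices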
(jnp.arange(all_vertices.shape[0]) * base_vertices.shape[0])[\n", + " :, None, None\n", + "] + base_faces[None, ...]\n", "print(all_faces.shape)\n", "\n", "# mesh = trimesh.Trimesh(vertices=base_vertices.reshape(-1,3), faces=base_faces.reshape(-1,3))\n", - "mesh = trimesh.Trimesh(vertices=all_vertices.reshape(-1,3), faces=all_faces.reshape(-1,3))\n", + "mesh = trimesh.Trimesh(\n", + " vertices=all_vertices.reshape(-1, 3), faces=all_faces.reshape(-1, 3)\n", + ")\n", "\n", "b.clear()\n", - "b.show_trimesh(\"1\",mesh)" + "b.show_trimesh(\"1\", mesh)" ] }, { @@ -715,11 +763,12 @@ "outputs": [], "source": [ "import os\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 10\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)" + "m = b.utils.scale_mesh(m, 1.0 / 100.0)" ] }, { @@ -729,7 +778,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/100.0)" + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 100.0)" ] }, { @@ -740,9 +789,9 @@ "outputs": [], "source": [ "pose = b.transform_from_pos(jnp.array([0.0, 0.0, 0.5]))\n", - "img = b.RENDERER.render(pose[None,...], jnp.array([0]))\n", - "print(img[...,2])\n", - "b.get_depth_image(img[...,2])" + "img = b.RENDERER.render(pose[None, ...], jnp.array([0]))\n", + "print(img[..., 2])\n", + "b.get_depth_image(img[..., 2])" ] }, { @@ -761,17 +810,24 @@ ], "source": [ "jax_renderer = JaxRenderer(intrinsics)\n", + "\n", + "\n", "def xfm_points(points, matrix):\n", - " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1)\n", + " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)\n", " return jnp.matmul(points, matrix.T)\n", "\n", + "\n", "def render(vertices, faces, object_pose, projection_matrix):\n", " final_mtx_proj = projection_matrix @ pose\n", - " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", " pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", "\n", - " rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], faces, jnp.array([intrinsics.height, intrinsics.width]))\n", - " gb_pos,_ = jax_renderer.interpolate(posw[None,...], rast_out, faces, rast_out_db, jnp.array([0,1,2,3]))\n", + " rast_out, rast_out_db = jax_renderer.rasterize(\n", + " pos_clip_ja[None, ...], faces, jnp.array([intrinsics.height, intrinsics.width])\n", + " )\n", + " gb_pos, _ = jax_renderer.interpolate(\n", + " posw[None, ...], rast_out, faces, rast_out_db, jnp.array([0, 1, 2, 3])\n", + " )\n", " mask = rast_out[..., -1] > 0\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", @@ -779,7 +835,7 @@ " depth = xfm_points(gb_pos, pose)\n", " depth = depth.reshape(shape_keep)[..., 2] * -1\n", " return -depth * mask + intrinsics.far * (1.0 - mask)\n", - " # rast_out, _ = jax_renderer.rasterize(new_cloud[None,...], jnp.array(b.RENDERER.meshes[0].faces), jnp.array([intrinsics.height, intrinsics.width]))\n" + " # rast_out, _ = jax_renderer.rasterize(new_cloud[None,...], jnp.array(b.RENDERER.meshes[0].faces), jnp.array([intrinsics.height, 
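# Editor note (illustrative sketch): the render function above composites the
# rasterized depth against the far plane -- pixels the mesh covers keep their
# camera-space depth, and everything else is written as `far`, mirroring what a
# depth sensor with a bounded range would report.
import jax.numpy as jnp

far = 3.5
depth = jnp.array([[1.2, 2.8], [0.0, 0.0]])    # 0 where no triangle was hit
mask = depth > 0
depth_image = depth * mask + far * (1.0 - mask)
print(depth_image)   # [[1.2 2.8] [3.5 3.5]]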
intrinsics.width]))" ] }, { @@ -802,14 +858,18 @@ ], "source": [ "proj = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", ")\n", "%%time\n", "img = render(b.RENDERER.meshes[0].vertices, b.RENDERER.meshes[0].faces, pose, proj)\n", - "b.get_depth_image(img[0,...])" + "b.get_depth_image(img[0, ...])" ] }, { @@ -853,8 +913,8 @@ } ], "source": [ - "cloud =b.apply_transform_jit(b.RENDERER.meshes[0].vertices, pose)\n", - "b.show_cloud(\"1\",cloud)\n", + "cloud = b.apply_transform_jit(b.RENDERER.meshes[0].vertices, pose)\n", + "b.show_cloud(\"1\", cloud)\n", "cloudw = b.add_homogenous_ones(cloud)\n", "print(cloudw)" ] @@ -883,15 +943,19 @@ "source": [ "proj = b.camera.getProjectionMatrix(intrinsics)\n", "proj = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", ")\n", "new_cloud = cloudw @ proj.T\n", "print(new_cloud.shape)\n", "print(new_cloud)\n", - "b.show_cloud(\"2\",new_cloud[:,:3] / new_cloud[:,3][...,None],color=b.RED)" + "b.show_cloud(\"2\", new_cloud[:, :3] / new_cloud[:, 3][..., None], color=b.RED)" ] }, { @@ -936,7 +1000,7 @@ } ], "source": [ - "rast_out[0,...,-1]" + "rast_out[0, ..., -1]" ] }, { @@ -1033,10 +1097,14 @@ } ], "source": [ - "rast_out, _ = jax_renderer.rasterize(new_cloud[None,...], jnp.array(b.RENDERER.meshes[0].faces), jnp.array([intrinsics.height, intrinsics.width]))\n", + "rast_out, _ = jax_renderer.rasterize(\n", + " new_cloud[None, ...],\n", + " jnp.array(b.RENDERER.meshes[0].faces),\n", + " jnp.array([intrinsics.height, intrinsics.width]),\n", + ")\n", "print(rast_out.sum())\n", - "plt.matshow(rast_out[0,...,-1])\n", - "print(jnp.unique(rast_out[0,...,-1]))\n", + "plt.matshow(rast_out[0, ..., -1])\n", + "print(jnp.unique(rast_out[0, ..., -1]))\n", "plt.colorbar()" ] }, @@ -1074,15 +1142,19 @@ "metadata": {}, "outputs": [], "source": [ - "x=1.0\n", - "n=0.01\n", - "f=5.0\n", - "proj = np.array([[n/x, 0, 0, 0],\n", - " [ 0, n/x, 0, 0],\n", - " [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],\n", - " [ 0, 0, -1, 0]])\n", + "x = 1.0\n", + "n = 0.01\n", + "f = 5.0\n", + "proj = np.array(\n", + " [\n", + " [n / x, 0, 0, 0],\n", + " [0, n / x, 0, 0],\n", + " [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],\n", + " [0, 0, -1, 0],\n", + " ]\n", + ")\n", "new_cloud = b.apply_transform_jit(-cloud, proj.T)\n", - "b.show_cloud(\"2\",new_cloud,color=b.RED)" + "b.show_cloud(\"2\", new_cloud, color=b.RED)" ] }, { diff --git a/scripts/experiments/gaussian_splatting/debug_renderer_jit.ipynb b/scripts/experiments/gaussian_splatting/debug_renderer_jit.ipynb index 6278d3dd..8ce82ff6 100644 --- a/scripts/experiments/gaussian_splatting/debug_renderer_jit.ipynb +++ b/scripts/experiments/gaussian_splatting/debug_renderer_jit.ipynb @@ -22,10 +22,11 @@ "import jax.numpy as jnp\n", "import jax\n", "import matplotlib.pyplot as plt\n", + "\n", + "\n", "def open3d_mesh_to_trimesh(mesh):\n", " return 
trimesh.Trimesh(\n", - " vertices=np.asarray(mesh.vertices),\n", - " faces=np.asarray(mesh.triangles)\n", + " vertices=np.asarray(mesh.vertices), faces=np.asarray(mesh.triangles)\n", " )" ] }, @@ -37,12 +38,9 @@ "outputs": [], "source": [ "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", + "\n", "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.1, far=3.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.1, far=3.5\n", ")\n", "jax_renderer = JaxRenderer(intrinsics)" ] @@ -55,11 +53,12 @@ "outputs": [], "source": [ "import os\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 10\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "mesh = b.utils.scale_mesh(m, 1.0/100.0)" + "mesh = b.utils.scale_mesh(m, 1.0 / 100.0)" ] }, { @@ -70,14 +69,19 @@ "outputs": [], "source": [ "def xfm_points(points, matrix):\n", - " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1)\n", + " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)\n", " return jnp.matmul(points, matrix.T)\n", "\n", + "\n", "projection_matrix = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", ")\n", "self = jax_renderer\n", "gt_z = 2.0\n", @@ -87,33 +91,40 @@ "vertices = mesh.vertices\n", "faces = mesh.faces\n", "final_mtx_proj = projection_matrix @ object_pose\n", - "posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + "posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", "pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", "\n", "\n", "def render(z):\n", " object_pose = b.transform_from_pos(jnp.array([0.0, 0.0, z]))\n", " final_mtx_proj = projection_matrix @ object_pose\n", - " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", " pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", - " rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], faces, jnp.array([intrinsics.height, intrinsics.width]))\n", + " rast_out, rast_out_db = jax_renderer.rasterize(\n", + " pos_clip_ja[None, ...], faces, jnp.array([intrinsics.height, intrinsics.width])\n", + " )\n", " # return rast_out[...,3],rast_out[...,3], rast_out.sum(), rast_out\n", - " gb_pos,_ = jax_renderer.interpolate(posw[None,...], rast_out, faces, rast_out_db, jnp.array([0,1,2,3]))\n", + " gb_pos, _ = jax_renderer.interpolate(\n", + " posw[None, ...], rast_out, faces, rast_out_db, jnp.array([0, 1, 2, 3])\n", + " )\n", " mask = rast_out[..., -1] > 0\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", " gb_pos = gb_pos[..., :3]\n", " depth = xfm_points(gb_pos, pose)\n", " depth = 
depth.reshape(shape_keep)[..., 2] * -1\n", - " return depth, mask, mask.sum(), rast_out[...,3]\n", - " \n", + " return depth, mask, mask.sum(), rast_out[..., 3]\n", + "\n", + "\n", "jax.clear_caches()\n", "render_jit = jax.jit(render)\n", "\n", + "\n", "def loss(z, gt_img):\n", - " a,b,c,d = render(z)\n", + " a, b, c, d = render(z)\n", " return jnp.abs(gt_img - a).mean()\n", "\n", + "\n", "grad_func = jax.value_and_grad(loss, argnums=(0,))\n", "grad_func_jit = jax.jit(grad_func)" ] @@ -148,11 +159,13 @@ "source": [ "gt_z = 2.2\n", "z = 2.5\n", - "gt_img,gt_mask,tmp,_ = render_jit(gt_z)\n", + "gt_img, gt_mask, tmp, _ = render_jit(gt_z)\n", "print(\"loss \", loss(gt_z, gt_img))\n", "print(\"loss \", grad_func(gt_z, gt_img))\n", "img, mask, _, _ = render_jit(z)\n", - "b.hstack_images([b.get_depth_image(gt_img[0,...]), b.get_depth_image(img[0,...]*1.0)])\n" + "b.hstack_images(\n", + " [b.get_depth_image(gt_img[0, ...]), b.get_depth_image(img[0, ...] * 1.0)]\n", + ")" ] }, { @@ -183,18 +196,24 @@ { "cell_type": "code", "execution_count": 7, + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "outputs": [], "source": [ "def xfm_points(points, matrix):\n", - " points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1)\n", + " points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)\n", " return jnp.matmul(points2, matrix.T)\n", "\n", + "\n", "projection_matrix = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", ")\n", "self = jax_renderer\n", "gt_z = 2.0\n", @@ -204,18 +223,21 @@ "vertices = mesh.vertices\n", "faces = mesh.faces\n", "final_mtx_proj = projection_matrix @ object_pose\n", - "posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + "posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", "pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", "\n", "\n", "def render(z):\n", " object_pose = b.transform_from_pos(jnp.array([0.0, 0.0, z]))\n", " final_mtx_proj = projection_matrix @ object_pose\n", - " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", + " posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1)\n", " pos_clip_ja = xfm_points(vertices, final_mtx_proj)\n", - " rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], faces, jnp.array([intrinsics.height, intrinsics.width]))\n", - " return rast_out[...,3], None #rast_out[0,0,...]\n", - " \n", + " rast_out, rast_out_db = jax_renderer.rasterize(\n", + " pos_clip_ja[None, ...], faces, jnp.array([intrinsics.height, intrinsics.width])\n", + " )\n", + " return rast_out[..., 3], None # rast_out[0,0,...]\n", + "\n", + "\n", "jax.clear_caches()\n", "render_jit = jax.jit(render)" ] @@ -223,6 +245,7 @@ { "cell_type": "code", "execution_count": 8, + "id": "acae54e37e7d407bbb7b55eff062a284", "metadata": {}, "outputs": [ { @@ -237,14 +260,15 @@ "source": [ "gt_z = 2.2\n", "z = 2.5\n", - "gt_img,_ = render_jit(gt_z)\n", + "gt_img, _ = render_jit(gt_z)\n", "print(\"loss jit\", jnp.abs(gt_img - render_jit(z)[0]).mean())\n", - "print(\"loss \", jnp.abs(gt_img - render(z)[0]).mean())\n" + "print(\"loss \", jnp.abs(gt_img - 
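# Editor note (hedged debugging sketch): the prints above compare the jitted and
# non-jitted render -- for a pure function the two must agree, and divergence is the
# classic symptom of a custom-call primitive misbehaving under jit. The generic form
# of the same check:
import jax
import jax.numpy as jnp

def f(z):
    return jnp.sin(z) * z**2    # stand-in for render(z)[0]

f_jit = jax.jit(f)
print(jnp.allclose(f(2.5), f_jit(2.5), atol=1e-6))   # True for a well-behaved f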
render(z)[0]).mean())" ] }, { "cell_type": "code", "execution_count": null, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": {}, "outputs": [], "source": [] diff --git a/scripts/experiments/gaussian_splatting/fast_particles.ipynb b/scripts/experiments/gaussian_splatting/fast_particles.ipynb index 6ce7e90c..71ed9420 100644 --- a/scripts/experiments/gaussian_splatting/fast_particles.ipynb +++ b/scripts/experiments/gaussian_splatting/fast_particles.ipynb @@ -62,7 +62,9 @@ "import numpy as np\n", "import torch\n", "import imageio\n", - "import bayes3d as b; b = b.bayes3d\n", + "import bayes3d as b\n", + "\n", + "b = b.bayes3d\n", "from tqdm import tqdm\n", "import jax.numpy as jnp\n", "import pytorch3d.transforms as t3d\n", @@ -102,22 +104,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 1e-3\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = [200,200]\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 1e-3\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = [200, 200]\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -136,9 +138,9 @@ ], "source": [ "# Misc helpers\n", - "def get_img_with_border(img, border=5, fill='red'):\n", + "def get_img_with_border(img, border=5, fill=\"red\"):\n", " cropped_img = ImageOps.crop(img, border=border)\n", - " return ImageOps.expand(cropped_img, border=border,fill=fill)\n", + " return ImageOps.expand(cropped_img, border=border, fill=fill)\n", "\n", "\n", "def set_seed(seed: int = 0) -> None:\n", @@ -151,6 +153,8 @@ " # Set a fixed value for the hash seed\n", " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", " print(f\"Random seed set as {seed}\")\n", + "\n", + "\n", "set_seed(0)" ] }, @@ -161,7 +165,7 @@ "metadata": {}, "outputs": [], "source": [ - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()" + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()" ] }, { @@ -186,18 +190,18 @@ } ], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 14\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)\n", + "m = b.utils.scale_mesh(m, 1.0 / 100.0)\n", "\n", "# m = b.utils.make_cuboid_mesh(jnp.array([0.5, 0.5, 0.2]))\n", "\n", "vtx_pos = torch.from_numpy(m.vertices.astype(np.float32)).cuda()\n", "pos_idx = torch.from_numpy(m.faces.astype(np.int32)).cuda()\n", - "col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0],3)).astype(np.int32)).cuda()\n", - "vtx_col = torch.from_numpy(np.ones((1,3)).astype(np.float32)).cuda()\n", + "col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0], 3)).astype(np.int32)).cuda()\n", + "vtx_col = torch.from_numpy(np.ones((1, 3)).astype(np.float32)).cuda()\n", "# print(\"Mesh 
has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "print(pos_idx.shape, vtx_pos.shape, col_idx.shape, vtx_col.shape)\n", "print(vtx_pos, vtx_col)" @@ -225,19 +229,25 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=5.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5\n", ")\n", "\n", - "mvp = proj_cam = torch.tensor(np.array(b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", - ")), device=device, dtype=torch.float32) # model-view-projection transformation\n", + "mvp = proj_cam = torch.tensor(\n", + " np.array(\n", + " b.camera._open_gl_projection_matrix(\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", + " )\n", + " ),\n", + " device=device,\n", + " dtype=torch.float32,\n", + ") # model-view-projection transformation\n", "mvp" ] }, @@ -251,18 +261,31 @@ "def posevec_to_matrix_single(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", "\n", + "\n", "def posevec_to_matrix_batch(positions, quats):\n", " batch_size = positions.shape[0]\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quats), positions.unsqueeze(2)), 2),\n", - " torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).repeat(batch_size,1,1),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quats),\n", + " positions.unsqueeze(2),\n", + " ),\n", + " 2,\n", + " ),\n", + " torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).repeat(batch_size, 1, 1),\n", " ),\n", " 1,\n", " )" @@ -280,24 +303,32 @@ " rot_mtx_44 = posevec_to_matrix_single(pos, quat)\n", "\n", " # preprocess and transform points into clip space\n", - " pos = vtx_pos[None,...].contiguous()\n", - " posw = torch.cat([pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2) # (xyz) -> (xyz1)\n", - " transform_mtx = torch.matmul(proj_cam, rot_mtx_44) # transform = projection + pose rotation\n", - " pos_clip_ja = dd.xfm_points(pos, transform_mtx[None,...]) # transform points\n", - " \n", + " pos = vtx_pos[None, ...].contiguous()\n", + " posw = torch.cat(\n", + " [pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2\n", + " ) # (xyz) -> (xyz1)\n", + " transform_mtx = torch.matmul(\n", + " proj_cam, rot_mtx_44\n", + " ) # transform = projection + pose rotation\n", + " pos_clip_ja = dd.xfm_points(pos, transform_mtx[None, ...]) # transform points\n", + "\n", " # rasterize and interpolate (in world space)\n", - " rast_out, rast_out_db = dr.rasterize(glctx, pos_clip_ja, pos_idx, resolution=resolution)\n", - " gb_pos, _ = dr.interpolate(posw.contiguous(), rast_out, pos_idx, rast_db=rast_out_db, diff_attrs=\"all\")\n", + " rast_out, rast_out_db = dr.rasterize(\n", + " glctx, pos_clip_ja, pos_idx, resolution=resolution\n", + " )\n", + " 
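# Editor note (hedged): per nvdiffrast's documentation, rasterize() returns four
# channels per pixel -- barycentrics (u, v), clip-space z/w, and triangle_id + 1,
# with 0 meaning "no triangle hit" -- so the robust coverage mask is the last
# channel. The render being reformatted here masks on channel 2 (z/w) a few lines
# below, which only matches while all visible geometry keeps z/w positive.
# Minimal mask sketch:
import torch

rast_out = torch.zeros(1, 4, 4, 4)                       # stand-in rasterizer output
rast_out[0, 1, 1] = torch.tensor([0.3, 0.3, 0.5, 7.0])   # one covered pixel
mask = rast_out[..., -1] > 0                             # background-safe mask
print(mask.sum().item())                                 # 1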
gb_pos, _ = dr.interpolate(\n", + " posw.contiguous(), rast_out, pos_idx, rast_db=rast_out_db, diff_attrs=\"all\"\n", + " )\n", "\n", " # Get depth values (in camera space)\n", " gb_pos = gb_pos.contiguous()\n", - " mask= rast_out[...,2] > 0\n", + " mask = rast_out[..., 2] > 0\n", " # return gb_pos[...,2], mask\n", "\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", " gb_pos = gb_pos[..., :3]\n", - " depth = dd.xfm_points(gb_pos, rot_mtx_44[None,...])\n", + " depth = dd.xfm_points(gb_pos, rot_mtx_44[None, ...])\n", " depth = depth.reshape(shape_keep)[..., 2] * -1\n", "\n", " return depth, mask" @@ -362,32 +393,37 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0, 2.3],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, 2.3], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "pose_gt = posevec_to_matrix_single(pos, quat)\n", "obs_depth, mask_gt = render(pos, quat)\n", "viz_gt = b.get_depth_image(jnp.array(obs_depth[0].cpu().numpy()))\n", - "viz_mask_gt = b.get_depth_image(jnp.array(mask_gt[0].cpu().numpy()) * 1.0,max=1.1)\n", - "b.viz.hstack_images([viz_gt,viz_mask_gt]).show()\n", + "viz_mask_gt = b.get_depth_image(jnp.array(mask_gt[0].cpu().numpy()) * 1.0, max=1.1)\n", + "b.viz.hstack_images([viz_gt, viz_mask_gt]).show()\n", "\n", - "pos = torch.tensor([0.0, 0.0,2.1],device=device, requires_grad=True)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device, requires_grad=True)\n", - "rendered_image,_ = render(pos,quat)\n", - "viz_orig = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", + "pos = torch.tensor([0.0, 0.0, 2.1], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " torch.rand(4, device=device) - 0.5, device=device, requires_grad=True\n", + ")\n", + "rendered_image, _ = render(pos, quat)\n", + "viz_orig = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", "pose_optim = posevec_to_matrix_single(pos, quat)\n", "\n", "b.show_pose(\"1\", pose_gt.detach().cpu().numpy())\n", "b.show_pose(\"2\", pose_optim.detach().cpu().numpy())\n", "\n", - "optimizer = torch.optim.SGD([\n", - " {'params': [pos], 'lr': 0.5, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.5, \"name\": \"quat\"},\n", - "], lr=0.0)\n", + "optimizer = torch.optim.SGD(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.5, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.5, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + ")\n", "print(quat)\n", "pbar = tqdm(range(200))\n", "\n", "for it in pbar:\n", - " rendered_image, _ = render(pos,quat)\n", + " rendered_image, _ = render(pos, quat)\n", " loss = torch.abs((obs_depth - rendered_image) * mask_gt).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", @@ -395,7 +431,7 @@ " pbar.set_description(f\"{loss.item()}\")\n", "\n", "pose_optim = posevec_to_matrix_single(pos, quat)\n", - "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", + "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", "\n", "b.clear()\n", "b.show_trimesh(\"1\", m, color=b.BLUE, opacity=0.5)\n", @@ -423,10 +459,10 @@ "def single_gd(it=20000):\n", " optimizer = torch.optim.Adam([pos], betas=(0.9, 0.999), lr=1e-5)\n", " optimizer2 = torch.optim.Adam([quat], betas=(0.9, 0.999), lr=1e-4)\n", - " \n", + "\n", " pbar = tqdm(range(it))\n", " for it in pbar:\n", - " 
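# Editor note (sketch of an equivalent pattern): single_gd below runs two Adam
# instances so position and quaternion get different learning rates. One Adam with
# per-parameter param groups does the same job with a single zero_grad()/step() pair:
import torch

pos = torch.zeros(3, requires_grad=True)
quat = torch.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)

optimizer = torch.optim.Adam([
    {"params": [pos], "lr": 1e-5},
    {"params": [quat], "lr": 1e-4},
], betas=(0.9, 0.999))

loss = (pos**2).sum() + (quat**2).sum()   # stand-in for the masked depth loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(pos.grad.shape, quat.grad.shape)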
rendered_image, _ = render(pos,quat)\n", + " rendered_image, _ = render(pos, quat)\n", " loss = torch.abs((obs_depth - rendered_image) * mask_gt).mean()\n", " optimizer.zero_grad()\n", " optimizer2.zero_grad()\n", @@ -437,12 +473,20 @@ " optimizer2.step()\n", "\n", " if it % 200 == 0:\n", - " b.hstack_images([b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy())), viz_gt]).show()\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(\n", + " jnp.array(rendered_image[0].detach().cpu().numpy())\n", + " ),\n", + " viz_gt,\n", + " ]\n", + " ).show()\n", "\n", " pbar.set_description(f\"{loss.item()}\")\n", - " viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", + " viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", " b.hstack_images([viz, viz_gt])\n", - " \n", + "\n", + "\n", "# single_gd()" ] }, @@ -728,22 +772,26 @@ ], "source": [ "# Ground Truth\n", - "pos = torch.tensor([0.0, 0.0, 2.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, 2.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "obs_depth, mask_gt = render(pos, quat)\n", "viz_gt = b.get_depth_image(jnp.array(obs_depth[0].cpu().numpy()))\n", - "viz_mask_gt = b.get_depth_image(jnp.array(mask_gt[0].cpu().numpy()) * 1.0,max=1.1)\n", - "b.viz.hstack_images([viz_gt,viz_mask_gt]).show()\n", + "viz_mask_gt = b.get_depth_image(jnp.array(mask_gt[0].cpu().numpy()) * 1.0, max=1.1)\n", + "b.viz.hstack_images([viz_gt, viz_mask_gt]).show()\n", "\n", "# Hypotheses\n", "num_hypos = 20\n", - "pos = torch.tensor([[0.0, 0.0, 2.4] for _ in range(num_hypos)],device=device, requires_grad=True)\n", - "quat = torch.tensor(torch.rand(num_hypos, 4,device=device) - 0.5,device=device, requires_grad=True)\n", + "pos = torch.tensor(\n", + " [[0.0, 0.0, 2.4] for _ in range(num_hypos)], device=device, requires_grad=True\n", + ")\n", + "quat = torch.tensor(\n", + " torch.rand(num_hypos, 4, device=device) - 0.5, device=device, requires_grad=True\n", + ")\n", "\n", - "rendered_images, _ = render_batch(pos,quat)\n", + "rendered_images, _ = render_batch(pos, quat)\n", "for rendered_image in rendered_images:\n", - " viz = b.get_depth_image(jnp.array(rendered_image.detach().cpu().numpy()))\n", - " b.scale_image(b.hstack_images([viz, viz_gt]),0.2).show()" + " viz = b.get_depth_image(jnp.array(rendered_image.detach().cpu().numpy()))\n", + " b.scale_image(b.hstack_images([viz, viz_gt]), 0.2).show()" ] }, { @@ -970,13 +1018,14 @@ "source": [ "N = len(pos)\n", "assert N % 2 == 0 or N % 3 == 0 or N % 5 == 0\n", - "if N % 5 == 0 and 5 < N//5:\n", - " vh, vw = 5, N//5 \n", + "if N % 5 == 0 and 5 < N // 5:\n", + " vh, vw = 5, N // 5\n", "elif N % 3 == 0:\n", - " vh, vw = 3, N//3\n", + " vh, vw = 3, N // 3\n", "else:\n", - " vh, vw = 2, N//2\n", - " \n", + " vh, vw = 2, N // 2\n", + "\n", + "\n", "def multi_gd(it=20000):\n", " prev_min_loss_idx = -1\n", "\n", @@ -984,7 +1033,7 @@ " optimizer = torch.optim.Adam([pos], betas=(0.9, 0.999), lr=1e-5)\n", " optimizer2 = torch.optim.Adam([quat], betas=(0.9, 0.999), lr=2e-4)\n", " for it in pbar:\n", - " rendered_images, _ = render_batch(pos,quat)\n", + " rendered_images, _ = render_batch(pos, quat)\n", " diffs = torch.abs((obs_depth - rendered_images) * mask_gt)\n", "\n", " loss = diffs.mean() * N\n", @@ -998,7 +1047,7 @@ " optimizer2.step()\n", "\n", " with torch.no_grad():\n", - " img_diffs = torch.mean(diffs * N, dim=(1,2))\n", 
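# Editor note (sketch of the hypothesis-selection step being reformatted here): the
# batched loss is reduced to one scalar per hypothesis, and torch.min over the batch
# dimension yields both the best loss and the index of the winning pose:
import torch

num_hypos, H, W = 20, 8, 8
diffs = torch.rand(num_hypos, H, W)      # per-pixel |depth error| per hypothesis
per_hypo = diffs.mean(dim=(1, 2))        # (20,): one loss per hypothesis
best_val, best_idx = torch.min(per_hypo, 0)
print(best_idx.item(), best_val.item())  # winning pose index and its loss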
+ " img_diffs = torch.mean(diffs * N, dim=(1, 2))\n", "\n", " min_loss_val, min_loss_idx = torch.min(img_diffs, 0)\n", "\n", @@ -1011,17 +1060,23 @@ " if it % 200 == 0:\n", " vizs = []\n", " for i, rendered_image in enumerate(rendered_images):\n", - " viz = b.get_depth_image(jnp.array(rendered_image.detach().cpu().numpy()))\n", - " if i == min_loss_idx: viz = get_img_with_border(viz, border=5)\n", + " viz = b.get_depth_image(\n", + " jnp.array(rendered_image.detach().cpu().numpy())\n", + " )\n", + " if i == min_loss_idx:\n", + " viz = get_img_with_border(viz, border=5)\n", " vizs.append(viz)\n", - " b.scale_image(b.hvstack_images(vizs, vh, vw), 0.2).show() \n", + " b.scale_image(b.hvstack_images(vizs, vh, vw), 0.2).show()\n", " b.scale_image(b.hstack_images([vizs[min_loss_idx], viz_gt]), 0.5).show()\n", "\n", " pbar.set_description(f\"{loss.item()}\")\n", "\n", - " best_viz = b.get_depth_image(jnp.array(rendered_images[min_loss_idx].detach().cpu().numpy()))\n", + " best_viz = b.get_depth_image(\n", + " jnp.array(rendered_images[min_loss_idx].detach().cpu().numpy())\n", + " )\n", " b.hstack_images([best_viz, viz_gt])\n", - " \n", + "\n", + "\n", "multi_gd(1000)" ] }, diff --git a/scripts/experiments/gaussian_splatting/nvdiffrast_diff.ipynb b/scripts/experiments/gaussian_splatting/nvdiffrast_diff.ipynb index 31162e15..5b592f2f 100644 --- a/scripts/experiments/gaussian_splatting/nvdiffrast_diff.ipynb +++ b/scripts/experiments/gaussian_splatting/nvdiffrast_diff.ipynb @@ -33,6 +33,7 @@ "import diffdope as dd\n", "import pytorch3d.transforms\n", "import matplotlib.pyplot as plt\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] @@ -63,22 +64,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 1e-3\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = [200,200]\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 1e-3\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = [200, 200]\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -88,7 +89,7 @@ "metadata": {}, "outputs": [], "source": [ - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()" + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()" ] }, { @@ -110,18 +111,18 @@ } ], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 14\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)\n", + "m = b.utils.scale_mesh(m, 1.0 / 100.0)\n", "\n", "# m = b.utils.make_cuboid_mesh(jnp.array([0.5, 0.5, 0.2]))\n", "\n", "vtx_pos = torch.from_numpy(m.vertices.astype(np.float32)).cuda()\n", "pos_idx = torch.from_numpy(m.faces.astype(np.int32)).cuda()\n", - 
"col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0],3)).astype(np.int32)).cuda()\n", - "vtx_col = torch.from_numpy(np.ones((1,3)).astype(np.float32)).cuda()\n", + "col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0], 3)).astype(np.int32)).cuda()\n", + "vtx_col = torch.from_numpy(np.ones((1, 3)).astype(np.float32)).cuda()\n", "# print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "print(pos_idx.shape, vtx_pos.shape, col_idx.shape, vtx_col.shape)\n", "print(vtx_pos, vtx_col)" @@ -137,8 +138,14 @@ "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )" @@ -166,19 +173,24 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=5.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5\n", ")\n", "\n", - "mvp = torch.tensor(np.array(b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", - ")), device=device)\n", + "mvp = torch.tensor(\n", + " np.array(\n", + " b.camera._open_gl_projection_matrix(\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", + " )\n", + " ),\n", + " device=device,\n", + ")\n", "mvp" ] }, @@ -192,21 +204,23 @@ "def render(pos, quat):\n", " mtx = posevec_to_matrix(pos, quat)\n", " proj_cam = mvp\n", - " pos = vtx_pos[None,...]\n", + " pos = vtx_pos[None, ...]\n", " posw = torch.cat([pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2)\n", "\n", " final_mtx_proj = torch.matmul(proj_cam, mtx)\n", - " pos_clip_ja = dd.xfm_points(pos.contiguous(), final_mtx_proj[None,...])\n", + " pos_clip_ja = dd.xfm_points(pos.contiguous(), final_mtx_proj[None, ...])\n", " # pos_clip_ja = torch.matmul(posw, torch.transpose(torch.matmul(proj_cam, mtx),1,0))\n", - " rast_out, rast_out_db = dr.rasterize(glctx, pos_clip_ja, pos_idx, resolution=resolution)\n", + " rast_out, rast_out_db = dr.rasterize(\n", + " glctx, pos_clip_ja, pos_idx, resolution=resolution\n", + " )\n", "\n", " gb_pos, _ = dr.interpolate(posw, rast_out, pos_idx, rast_db=rast_out_db)\n", - " mask= rast_out[...,2] > 0\n", + " mask = rast_out[..., 2] > 0\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", " gb_pos = gb_pos[..., :3]\n", "\n", - " depth = dd.xfm_points(gb_pos.contiguous(), mtx[None,...])\n", + " depth = dd.xfm_points(gb_pos.contiguous(), mtx[None, ...])\n", " depth = depth.reshape(shape_keep)[..., 2]\n", " return depth, mask" ] @@ -238,13 +252,13 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0, 2.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, 2.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "obs_depth, mask = render(pos, quat)\n", "viz_gt = 
b.get_depth_image(jnp.array(obs_depth[0].cpu().numpy()))\n", - "mask_gt = b.get_depth_image(jnp.array(mask[0].cpu().numpy()) * 1.0,max=1.1)\n", + "mask_gt = b.get_depth_image(jnp.array(mask[0].cpu().numpy()) * 1.0, max=1.1)\n", "# print(depth[...,None])\n", - "b.viz.hstack_images([viz_gt,mask_gt])" + "b.viz.hstack_images([viz_gt, mask_gt])" ] }, { @@ -274,10 +288,12 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0,2.4],device=device, requires_grad=True)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device, requires_grad=True)\n", - "rendered_image,_ = render(pos,quat)\n", - "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", + "pos = torch.tensor([0.0, 0.0, 2.4], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " torch.rand(4, device=device) - 0.5, device=device, requires_grad=True\n", + ")\n", + "rendered_image, _ = render(pos, quat)\n", + "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", "b.hstack_images([viz, viz_gt])" ] }, @@ -315,21 +331,24 @@ } ], "source": [ - "optimizer = torch.optim.SGD([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 1.0, \"name\": \"quat\"},\n", - "], lr=0.0)\n", + "optimizer = torch.optim.SGD(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 1.0, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + ")\n", "print(quat)\n", "\n", "pbar = tqdm(range(300))\n", "for _ in pbar:\n", - " rendered_image,mask2 = render(pos,quat)\n", + " rendered_image, mask2 = render(pos, quat)\n", " loss = torch.abs((obs_depth - rendered_image) * mask).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " pbar.set_description(f\"{loss.item()}\")\n", - "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", + "viz = b.get_depth_image(jnp.array(rendered_image[0].detach().cpu().numpy()))\n", "b.hstack_images([viz, viz_gt])" ] }, @@ -342,8 +361,8 @@ "source": [ "b.clear()\n", "depth_np = jnp.array(depth[0].detach().cpu().numpy())\n", - "cloud = b.unproject_depth_jit(depth_np, intrinsics).reshape(-1,3)\n", - "b.show_cloud(\"1\",cloud)\n", + "cloud = b.unproject_depth_jit(depth_np, intrinsics).reshape(-1, 3)\n", + "b.show_cloud(\"1\", cloud)\n", "\n", "b.show_trimesh(\"m\", m)\n", "b.set_pose(\"m\", jnp.array(mtx.cpu().numpy()))" @@ -391,7 +410,7 @@ } ], "source": [ - "pose = jnp.array(mtx.cpu().numpy())\n", + "pose = jnp.array(mtx.cpu().numpy())\n", "pose" ] }, @@ -420,18 +439,15 @@ "import bayes3d as b\n", "import os\n", "import jax.numpy as jnp\n", + "\n", "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=5.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5\n", ")\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 14\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)" + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)" ] }, { @@ -454,13 +470,15 @@ ], "source": [ "pose = jnp.array(\n", - " 
[[-0.45291662, 0.41978574, 0.78654087, 0. ],\n", - " [-0.4266136 , 0.6726148 , -0.60464054, 0. ],\n", - " [-0.78285843, -0.60940075, -0.12555218, 2.5 ],\n", - " [ 0. , 0. , 0. , 1. ]]\n", + " [\n", + " [-0.45291662, 0.41978574, 0.78654087, 0.0],\n", + " [-0.4266136, 0.6726148, -0.60464054, 0.0],\n", + " [-0.78285843, -0.60940075, -0.12555218, 2.5],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " ]\n", ")\n", - "img = b.RENDERER.render(pose[None,...], jnp.array([0]))\n", - "b.get_depth_image(img[...,2])" + "img = b.RENDERER.render(pose[None, ...], jnp.array([0]))\n", + "b.get_depth_image(img[..., 2])" ] }, { @@ -512,6 +530,7 @@ ], "source": [ "from bayes3d.viz.open3dviz import Open3DVisualizer\n", + "\n", "visualizer = Open3DVisualizer(intrinsics)" ] }, @@ -534,7 +553,9 @@ } ], "source": [ - "visualizer.make_mesh_from_file(mesh_path, jnp.array(mtx.cpu().numpy()), scaling_factor=1.0/1000.0)\n", + "visualizer.make_mesh_from_file(\n", + " mesh_path, jnp.array(mtx.cpu().numpy()), scaling_factor=1.0 / 1000.0\n", + ")\n", "rgbd = visualizer.capture_image(intrinsics, jnp.eye(4))\n", "b.get_rgb_image(rgbd.rgb)" ] @@ -569,7 +590,7 @@ ], "source": [ "b.setup_renderer(intrinsics)\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)" + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)" ] }, { @@ -669,7 +690,7 @@ } ], "source": [ - "depth[...,None]" + "depth[..., None]" ] }, { @@ -729,9 +750,9 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0, -5.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", - "mtx = dd.matrix_batch_44_from_position_quat(quat[None,...],pos[None,...])\n", + "pos = torch.tensor([0.0, 0.0, -5.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", + "mtx = dd.matrix_batch_44_from_position_quat(quat[None, ...], pos[None, ...])\n", "camera = dd.Camera(\n", " fx=100.0,\n", " fy=100.0,\n", @@ -741,17 +762,15 @@ " im_height=100,\n", ")\n", "camera.cuda()\n", - "proj_cam = camera.cam_proj[None,...]\n", - "pos = vtx_pos[None,...]\n", + "proj_cam = camera.cam_proj[None, ...]\n", + "pos = vtx_pos[None, ...]\n", "\n", "posw = torch.cat([pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2)\n", "\n", "final_mtx_proj = torch.matmul(proj_cam, mtx)\n", "pos_clip_ja = dd.xfm_points(pos.contiguous(), final_mtx_proj)\n", "\n", - "rast_out, rast_out_db = dr.rasterize(\n", - " glctx, pos_clip_ja, pos_idx, resolution=resolution\n", - ")\n", + "rast_out, rast_out_db = dr.rasterize(glctx, pos_clip_ja, pos_idx, resolution=resolution)\n", "\n", "gb_pos, _ = dr.interpolate(posw, rast_out, pos_idx, rast_db=rast_out_db)\n", "shape_keep = gb_pos.shape\n", @@ -790,10 +809,9 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "ground_truth_image = render_wrapper(pos,quat)\n", + "ground_truth_image = render_wrapper(pos, quat)\n", "viz_gt = get_viz(ground_truth_image)\n", - "viz_gt\n" + "viz_gt" ] }, { @@ -814,8 +832,8 @@ " Returns:\n", " Rotation matrices as tensor of shape (..., 3, 3).\n", " \"\"\"\n", - " positions = poses[...,:3]\n", - " quaternions = poses[...,3:]\n", + " positions = poses[..., :3]\n", + " quaternions = poses[..., 3:]\n", " r, i, j, k = torch.unbind(quaternions, -1)\n", " x, y, z = torch.unbind(positions, -1)\n", " # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n", @@ -845,6 +863,7 @@ " rotation_matrices = o.reshape(quaternions.shape[:-1] + (4, 4))\n", " return rotation_matrices\n", "\n", + "\n", "# Transform vertex 
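# Editor note: the custom quaternion_to_matrix defined above was adapted from
# pytorch3d (the pyre-fixme comment is carried over from that source) but extended to
# consume a 7-vector pose (xyz translation + wxyz quaternion) and emit 4x4
# homogeneous matrices; the retained docstring still says "(..., 3, 3)", which no
# longer matches. A hedged sanity check of the rotation block against pytorch3d,
# whose availability here is an assumption:
import torch
import pytorch3d.transforms

pose = torch.tensor([0.0, 0.0, -5.0, 1.0, 1.2, 0.4, 1.0])   # xyz + quaternion
R_ref = pytorch3d.transforms.quaternion_to_matrix(pose[3:])  # (3, 3)
# The custom function's output should embed R_ref as its rotation block, with the
# translation carried in the remaining entries of the 4x4 matrix.
print(R_ref.shape)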
positions to clip space\n", "def transform_pos(mtx, pos):\n", " t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx, np.ndarray) else mtx\n", @@ -852,16 +871,21 @@ " posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)\n", " return torch.matmul(posw, t_mtx.t())[None, ...]\n", "\n", + "\n", "def render(glctx, mtx, pos, pos_idx, resolution: int):\n", " # Setup TF graph for reference.\n", " depth_ = pos[..., 2:3]\n", - " depth = torch.tensor([[[(z_val/1)] for z_val in depth_.squeeze()]], dtype=torch.float32).cuda()\n", - " pos_clip = transform_pos(mtx, pos)\n", - " rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution])\n", - " color , _ = dr.interpolate(depth, rast_out, pos_idx)\n", + " depth = torch.tensor(\n", + " [[[(z_val / 1)] for z_val in depth_.squeeze()]], dtype=torch.float32\n", + " ).cuda()\n", + " pos_clip = transform_pos(mtx, pos)\n", + " rast_out, _ = dr.rasterize(\n", + " glctx, pos_clip, pos_idx, resolution=[resolution, resolution]\n", + " )\n", + " color, _ = dr.interpolate(depth, rast_out, pos_idx)\n", " # color = dr.antialias(color, rast_out, pos_clip, pos_idx)\n", " return color\n", - " # return rast_out[:,:,:,2:3]\n" + " # return rast_out[:,:,:,2:3]" ] }, { @@ -872,13 +896,14 @@ "outputs": [], "source": [ "datadir = \"/home/nishadgothoskar/bayes3d/nvdiffrast/samples/data\"\n", - "with np.load(f'{datadir}/cube_p.npz') as f:\n", + "with np.load(f\"{datadir}/cube_p.npz\") as f:\n", " pos_idx, pos, col_idx, col = f.values()\n", "print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "\n", "# Some input geometry contains vertex positions in (N, 4) (with v[:,3]==1). Drop\n", "# the last column in that case.\n", - "if pos.shape[1] == 4: pos = pos[:, 0:3]\n", + "if pos.shape[1] == 4:\n", + " pos = pos[:, 0:3]\n", "\n", "# Create position/triangle index tensors\n", "pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()\n", @@ -913,11 +938,19 @@ "outputs": [], "source": [ "pose_target = torch.tensor([0.0, 0.0, -5.0, 1.0, 1.2, 0.4, 1.0]).cuda()\n", - "rast_target = render(glctx, torch.matmul(mvp, quaternion_to_matrix(pose_target)), vtx_pos, pos_idx, resolution)\n", - "img_target = rast_target[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n" + "rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, quaternion_to_matrix(pose_target)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " resolution,\n", + ")\n", + "img_target = rast_target[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -927,15 +960,28 @@ "metadata": {}, "outputs": [], "source": [ - "pose_opt = torch.tensor([0.0, 0.0, -6.0, 1.0, 1.2, 0.4, 1.0],dtype=torch.float32, device='cuda', requires_grad=True)\n", - "loss_best = np.inf\n", + "pose_opt = torch.tensor(\n", + " [0.0, 0.0, -6.0, 1.0, 1.2, 0.4, 1.0],\n", + " dtype=torch.float32,\n", + " device=\"cuda\",\n", + " requires_grad=True,\n", + ")\n", + "loss_best = np.inf\n", "\n", - "rast_opt = render(glctx, torch.matmul(mvp, quaternion_to_matrix(pose_opt)), vtx_pos, pos_idx, resolution)\n", - "img_opt = rast_opt[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n" + "rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, quaternion_to_matrix(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " 
resolution,\n", + ")\n", + "img_opt = rast_opt[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -945,7 +991,7 @@ "metadata": {}, "outputs": [], "source": [ - "diff = (rast_opt - rast_target)**2 # L2 norm.\n", + "diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", "diff.sum()" ] }, @@ -956,33 +1002,41 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam([pose_opt], lr=0.00001)\n", + "optimizer = torch.optim.Adam([pose_opt], lr=0.00001)\n", "images = []\n", "\n", - "for _ in tqdm(range(200)): \n", - " rast_opt = render(glctx, torch.matmul(mvp, quaternion_to_matrix(pose_opt)), vtx_pos, pos_idx, resolution)\n", + "for _ in tqdm(range(200)):\n", + " rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, quaternion_to_matrix(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " resolution,\n", + " )\n", "\n", - " diff = (rast_opt - rast_target)**2 # L2 norm.\n", + " diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", - " \n", + "\n", " if (loss_val < loss_best) and (loss_val > 0.0):\n", " loss_best = loss_val\n", - " \n", + "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", - " print(loss) \n", + " print(loss)\n", " with torch.no_grad():\n", - " pose_opt /= torch.sum(pose_opt**2)**0.5\n", - " \n", - " img_opt = rast_opt[0].detach().cpu().numpy()\n", + " pose_opt /= torch.sum(pose_opt**2) ** 0.5\n", + "\n", + " img_opt = rast_opt[0].detach().cpu().numpy()\n", " images.append(\n", - " b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - " ])\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + " )\n", " )" ] }, @@ -993,7 +1047,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.vstack_images([images[0],images[-1]])" + "b.vstack_images([images[0], images[-1]])" ] }, { diff --git a/scripts/experiments/gaussian_splatting/nvdiffrast_diff_original.ipynb b/scripts/experiments/gaussian_splatting/nvdiffrast_diff_original.ipynb index 9fa0af0f..69c9c89a 100644 --- a/scripts/experiments/gaussian_splatting/nvdiffrast_diff_original.ipynb +++ b/scripts/experiments/gaussian_splatting/nvdiffrast_diff_original.ipynb @@ -17,7 +17,7 @@ "import bayes3d as b\n", "from tqdm import tqdm\n", "\n", - "import nvdiffrast.torch as dr\n" + "import nvdiffrast.torch as dr" ] }, { @@ -27,22 +27,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 0.01\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = 256\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 0.01\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = 256\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -53,17 +53,27 @@ "outputs": [], "source": [ "def 
projection(x=0.1, n=1.0, f=50.0):\n", - " return np.array([[n/x, 0, 0, 0],\n", - " [ 0, n/x, 0, 0],\n", - " [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],\n", - " [ 0, 0, -1, 0]]).astype(np.float32)\n", + " return np.array(\n", + " [\n", + " [n / x, 0, 0, 0],\n", + " [0, n / x, 0, 0],\n", + " [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],\n", + " [0, 0, -1, 0],\n", + " ]\n", + " ).astype(np.float32)\n", + "\n", + "\n", "def translate(x, y, z):\n", - " return np.array([[1, 0, 0, x],\n", - " [0, 1, 0, y],\n", - " [0, 0, 1, z],\n", - " [0, 0, 0, 1]]).astype(np.float32)\n", - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()\n", - "mvp = torch.tensor(np.matmul(projection(x=0.4), translate(0, 0, -3.5)).astype(np.float32), device='cuda')\n" + " return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]).astype(\n", + " np.float32\n", + " )\n", + "\n", + "\n", + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()\n", + "mvp = torch.tensor(\n", + " np.matmul(projection(x=0.4), translate(0, 0, -3.5)).astype(np.float32),\n", + " device=\"cuda\",\n", + ")" ] }, { @@ -73,33 +83,65 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "# Quaternion math.\n", - "#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "\n", "# Unit quaternion.\n", "def q_unit():\n", " return np.asarray([1, 0, 0, 0], np.float32)\n", "\n", + "\n", "# Get a random normalized quaternion.\n", "def q_rnd():\n", " u, v, w = np.random.uniform(0.0, 1.0, size=[3])\n", " v *= 2.0 * np.pi\n", " w *= 2.0 * np.pi\n", - " return np.asarray([(1.0-u)**0.5 * np.sin(v), (1.0-u)**0.5 * np.cos(v), u**0.5 * np.sin(w), u**0.5 * np.cos(w)], np.float32)\n", + " return np.asarray(\n", + " [\n", + " (1.0 - u) ** 0.5 * np.sin(v),\n", + " (1.0 - u) ** 0.5 * np.cos(v),\n", + " u**0.5 * np.sin(w),\n", + " u**0.5 * np.cos(w),\n", + " ],\n", + " np.float32,\n", + " )\n", + "\n", "\n", "# Get a random quaternion from the octahedral symmetric group S_4.\n", "_r2 = 0.5**0.5\n", - "_q_S4 = [[ 1.0, 0.0, 0.0, 0.0], [ 0.0, 1.0, 0.0, 0.0], [ 0.0, 0.0, 1.0, 0.0], [ 0.0, 0.0, 0.0, 1.0],\n", - " [-0.5, 0.5, 0.5, 0.5], [-0.5,-0.5,-0.5, 0.5], [ 0.5,-0.5, 0.5, 0.5], [ 0.5, 0.5,-0.5, 0.5],\n", - " [ 0.5, 0.5, 0.5, 0.5], [-0.5, 0.5,-0.5, 0.5], [ 0.5,-0.5,-0.5, 0.5], [-0.5,-0.5, 0.5, 0.5],\n", - " [ _r2,-_r2, 0.0, 0.0], [ _r2, _r2, 0.0, 0.0], [ 0.0, 0.0, _r2, _r2], [ 0.0, 0.0,-_r2, _r2],\n", - " [ 0.0, _r2, _r2, 0.0], [ _r2, 0.0, 0.0,-_r2], [ _r2, 0.0, 0.0, _r2], [ 0.0,-_r2, _r2, 0.0],\n", - " [ _r2, 0.0, _r2, 0.0], [ 0.0, _r2, 0.0, _r2], [ _r2, 0.0,-_r2, 0.0], [ 0.0,-_r2, 0.0, _r2]]\n", + "_q_S4 = [\n", + " [1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 1.0, 0.0, 0.0],\n", + " [0.0, 0.0, 1.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " [-0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, -0.5, -0.5, 0.5],\n", + " [0.5, -0.5, 0.5, 0.5],\n", + " [0.5, 0.5, -0.5, 0.5],\n", + " [0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, 0.5, -0.5, 0.5],\n", + " [0.5, -0.5, -0.5, 0.5],\n", + " [-0.5, -0.5, 0.5, 0.5],\n", + " [_r2, -_r2, 0.0, 0.0],\n", + " [_r2, _r2, 0.0, 0.0],\n", + " [0.0, 0.0, _r2, _r2],\n", + " [0.0, 0.0, -_r2, _r2],\n", + " [0.0, _r2, _r2, 0.0],\n", + " [_r2, 0.0, 0.0, -_r2],\n", + " [_r2, 0.0, 0.0, _r2],\n", + " [0.0, -_r2, _r2, 0.0],\n", + " [_r2, 0.0, 
_r2, 0.0],\n", + " [0.0, _r2, 0.0, _r2],\n", + " [_r2, 0.0, -_r2, 0.0],\n", + " [0.0, -_r2, 0.0, _r2],\n", + "]\n", + "\n", + "\n", "def q_rnd_S4():\n", " return np.asarray(_q_S4[np.random.randint(24)], np.float32)\n", "\n", + "\n", "# Quaternion slerp.\n", "def q_slerp(p, q, t):\n", " d = np.dot(p, q)\n", @@ -107,28 +149,31 @@ " q = -q\n", " d = -d\n", " if d > 0.999:\n", - " a = p + t * (q-p)\n", + " a = p + t * (q - p)\n", " return a / np.linalg.norm(a)\n", " t0 = np.arccos(d)\n", " tt = t0 * t\n", " st = np.sin(tt)\n", " st0 = np.sin(t0)\n", " s1 = st / st0\n", - " s0 = np.cos(tt) - d*s1\n", - " return s0*p + s1*q\n", + " s0 = np.cos(tt) - d * s1\n", + " return s0 * p + s1 * q\n", + "\n", "\n", "# Quaterion scale (slerp vs. identity quaternion).\n", "def q_scale(q, scl):\n", " return q_slerp(q_unit(), q, scl)\n", "\n", + "\n", "# Quaternion product.\n", "def q_mul(p, q):\n", " s1, V1 = p[0], p[1:]\n", " s2, V2 = q[0], q[1:]\n", - " s = s1*s2 - np.dot(V1, V2)\n", - " V = s1*V2 + s2*V1 + np.cross(V1, V2)\n", + " s = s1 * s2 - np.dot(V1, V2)\n", + " V = s1 * V2 + s2 * V1 + np.cross(V1, V2)\n", " return np.asarray([s, V[0], V[1], V[2]], np.float32)\n", "\n", + "\n", "# Angular difference between two quaternions in degrees.\n", "def q_angle_deg(p, q):\n", " p = p.detach().cpu().numpy()\n", @@ -137,24 +182,49 @@ " d = min(d, 1.0)\n", " return np.degrees(2.0 * np.arccos(d))\n", "\n", + "\n", "# Quaternion product\n", "def q_mul_torch(p, q):\n", - " a = p[0]*q[0] - p[1]*q[1] - p[2]*q[2] - p[3]*q[3]\n", - " b = p[0]*q[1] + p[1]*q[0] + p[2]*q[3] - p[3]*q[2]\n", - " c = p[0]*q[2] + p[2]*q[0] + p[3]*q[1] - p[1]*q[3]\n", - " d = p[0]*q[3] + p[3]*q[0] + p[1]*q[2] - p[2]*q[1]\n", + " a = p[0] * q[0] - p[1] * q[1] - p[2] * q[2] - p[3] * q[3]\n", + " b = p[0] * q[1] + p[1] * q[0] + p[2] * q[3] - p[3] * q[2]\n", + " c = p[0] * q[2] + p[2] * q[0] + p[3] * q[1] - p[1] * q[3]\n", + " d = p[0] * q[3] + p[3] * q[0] + p[1] * q[2] - p[2] * q[1]\n", " return torch.stack([a, b, c, d])\n", "\n", + "\n", "# Convert quaternion to 4x4 rotation matrix.\n", "def q_to_mtx(q):\n", - " r0 = torch.stack([1.0-2.0*q[1]**2 - 2.0*q[2]**2, 2.0*q[0]*q[1] - 2.0*q[2]*q[3], 2.0*q[0]*q[2] + 2.0*q[1]*q[3]])\n", - " r1 = torch.stack([2.0*q[0]*q[1] + 2.0*q[2]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[2]**2, 2.0*q[1]*q[2] - 2.0*q[0]*q[3]])\n", - " r2 = torch.stack([2.0*q[0]*q[2] - 2.0*q[1]*q[3], 2.0*q[1]*q[2] + 2.0*q[0]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[1]**2])\n", + " r0 = torch.stack(\n", + " [\n", + " 1.0 - 2.0 * q[1] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[0] * q[1] - 2.0 * q[2] * q[3],\n", + " 2.0 * q[0] * q[2] + 2.0 * q[1] * q[3],\n", + " ]\n", + " )\n", + " r1 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[1] + 2.0 * q[2] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[1] * q[2] - 2.0 * q[0] * q[3],\n", + " ]\n", + " )\n", + " r2 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[2] - 2.0 * q[1] * q[3],\n", + " 2.0 * q[1] * q[2] + 2.0 * q[0] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[1] ** 2,\n", + " ]\n", + " )\n", " rr = torch.transpose(torch.stack([r0, r1, r2]), 1, 0)\n", - " rr = torch.cat([rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1) # Pad right column.\n", - " rr = torch.cat([rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0) # Pad bottom row.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1\n", + " ) # Pad right column.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0, 0, 0, 1]], 
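The `q_slerp` above is standard shortest-arc spherical interpolation with a lerp-and-renormalize fallback when the quaternions are nearly parallel. A self-contained check of the convention (w-first quaternions; interpolating halfway between the identity and a 90-degree rotation about x should give the 45-degree rotation):

```python
import numpy as np


def q_slerp(p, q, t):
    d = np.dot(p, q)
    if d < 0.0:  # take the short arc
        q, d = -q, -d
    if d > 0.999:  # nearly parallel: lerp and renormalize
        a = p + t * (q - p)
        return a / np.linalg.norm(a)
    t0 = np.arccos(d)
    s1 = np.sin(t0 * t) / np.sin(t0)
    s0 = np.cos(t0 * t) - d * s1
    return s0 * p + s1 * q


p = np.array([1.0, 0.0, 0.0, 0.0], np.float32)  # identity
q = np.array([np.cos(np.pi / 4), np.sin(np.pi / 4), 0.0, 0.0], np.float32)  # 90 deg about x
print(q_slerp(p, q, 0.5))  # ~[cos(pi/8), sin(pi/8), 0, 0], i.e. 45 deg about x
```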
dtype=torch.float32).cuda()], dim=0\n", + " ) # Pad bottom row.\n", " return rr\n", "\n", + "\n", "# Transform vertex positions to clip space\n", "def transform_pos(mtx, pos):\n", " t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx, np.ndarray) else mtx\n", @@ -162,16 +232,21 @@ " posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)\n", " return torch.matmul(posw, t_mtx.t())[None, ...]\n", "\n", + "\n", "def render(glctx, mtx, pos, pos_idx, col, col_idx, resolution: int):\n", " # Setup TF graph for reference.\n", " depth_ = pos[..., 2:3]\n", - " depth = torch.tensor([[[(z_val/1)] for z_val in depth_.squeeze()]], dtype=torch.float32).cuda()\n", - " pos_clip = transform_pos(mtx, pos)\n", - " rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution])\n", - " color , _ = dr.interpolate(depth, rast_out, pos_idx)\n", + " depth = torch.tensor(\n", + " [[[(z_val / 1)] for z_val in depth_.squeeze()]], dtype=torch.float32\n", + " ).cuda()\n", + " pos_clip = transform_pos(mtx, pos)\n", + " rast_out, _ = dr.rasterize(\n", + " glctx, pos_clip, pos_idx, resolution=[resolution, resolution]\n", + " )\n", + " color, _ = dr.interpolate(depth, rast_out, pos_idx)\n", " # color = dr.antialias(color, rast_out, pos_clip, pos_idx)\n", " return color\n", - " # return rast_out[:,:,:,2:3]\n" + " # return rast_out[:,:,:,2:3]" ] }, { @@ -182,13 +257,14 @@ "outputs": [], "source": [ "datadir = \"/home/nishadgothoskar/bayes3d/nvdiffrast/samples/data/\"\n", - "with np.load(f'{datadir}/cube_p.npz') as f:\n", + "with np.load(f\"{datadir}/cube_p.npz\") as f:\n", " pos_idx, pos, col_idx, col = f.values()\n", "print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "\n", "# Some input geometry contains vertex positions in (N, 4) (with v[:,3]==1). 
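For readers new to nvdiffrast, the `render` above produces a depth image by rasterizing first and then interpolating the per-vertex z as a vertex attribute. The interpolation itself is plain barycentric weighting; a dependency-free sketch for a single pixel (the (u, v) values stand in for what `dr.rasterize` would report, with the third weight being 1 - u - v):

```python
import numpy as np

# One triangle's per-vertex depth attribute and a pixel's barycentrics (u, v).
z = np.array([2.0, 3.0, 5.0])
u, v = 0.25, 0.25
w = 1.0 - u - v
pixel_depth = u * z[0] + v * z[1] + w * z[2]
print(pixel_depth)  # 0.25*2 + 0.25*3 + 0.5*5 = 3.75
```

Incidentally, the list-comprehension construction of `depth` in the hunk appears equivalent to `pos[..., 2:3][None]`, which would avoid a Python-level loop over the vertices.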
Drop\n", "# the last column in that case.\n", - "if pos.shape[1] == 4: pos = pos[:, 0:3]\n", + "if pos.shape[1] == 4:\n", + " pos = pos[:, 0:3]\n", "\n", "# Create position/triangle index tensors\n", "pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()\n", @@ -217,24 +293,48 @@ "metadata": {}, "outputs": [], "source": [ - "orientation_target = torch.tensor(q_rnd(), device='cuda')\n", - "position_target = torch.tensor(q_rnd(), device='cuda')\n", - "rast_target = render(glctx, torch.matmul(mvp, q_to_mtx(pose_target)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_target = rast_target[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n", - "pose_init = pose_target.cpu().numpy() + 0.3\n", - "pose_opt = torch.tensor(pose_init / np.sum(pose_init**2)**0.5, dtype=torch.float32, device='cuda', requires_grad=True)\n", - "loss_best = np.inf\n", - "\n", - "rast_opt = render(glctx, torch.matmul(mvp, q_to_mtx(pose_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_opt = rast_opt[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n", - "\n" + "orientation_target = torch.tensor(q_rnd(), device=\"cuda\")\n", + "position_target = torch.tensor(q_rnd(), device=\"cuda\")\n", + "rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_target)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_target = rast_target[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")\n", + "pose_init = pose_target.cpu().numpy() + 0.3\n", + "pose_opt = torch.tensor(\n", + " pose_init / np.sum(pose_init**2) ** 0.5,\n", + " dtype=torch.float32,\n", + " device=\"cuda\",\n", + " requires_grad=True,\n", + ")\n", + "loss_best = np.inf\n", + "\n", + "rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_opt = rast_opt[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -250,31 +350,41 @@ "for _ in pbar:\n", " noise = q_unit()\n", " pose_total_opt = q_mul_torch(pose_opt, noise)\n", - " mtx_total_opt = torch.matmul(mvp, q_to_mtx(pose_total_opt))\n", - " rast_opt = render(glctx, torch.matmul(mvp, q_to_mtx(pose_total_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", + " mtx_total_opt = torch.matmul(mvp, q_to_mtx(pose_total_opt))\n", + " rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_total_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + " )\n", "\n", - " diff = (rast_opt - rast_target)**2 # L2 norm.\n", + " diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", " diff = torch.tanh(5.0 * torch.max(diff, dim=-1)[0])\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", " pbar.set_description(f\"{loss_val}\")\n", - " \n", + "\n", " if (loss_val < loss_best) and (loss_val > 0.0):\n", " loss_best = loss_val\n", - " \n", + "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " \n", + "\n", " with torch.no_grad():\n", - " pose_opt /= torch.sum(pose_opt**2)**0.5\n", - " \n", - " 
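Note the loss in this loop: the squared pixel difference is passed through `tanh(5.0 * ...)`, which caps each pixel's contribution at 1 and so limits the influence of badly wrong pixels, similar in spirit to a robust loss. A toy illustration:

```python
import torch

err = torch.tensor([0.0, 0.1, 1.0, 10.0])
print(err**2)                    # raw L2: the outlier dominates the mean
print(torch.tanh(5.0 * err**2))  # saturated: the outlier contributes at most 1
```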
img_opt = rast_opt[0].detach().cpu().numpy()\n", + " pose_opt /= torch.sum(pose_opt**2) ** 0.5\n", + "\n", + " img_opt = rast_opt[0].detach().cpu().numpy()\n", " images.append(\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + " )\n", " )" ] }, @@ -285,7 +395,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.vstack_images([images[0],images[-1]])" + "b.vstack_images([images[0], images[-1]])" ] }, { diff --git a/scripts/experiments/gaussian_splatting/nvdiffrast_jax_optim.ipynb b/scripts/experiments/gaussian_splatting/nvdiffrast_jax_optim.ipynb index a0032d01..3a9a0e66 100644 --- a/scripts/experiments/gaussian_splatting/nvdiffrast_jax_optim.ipynb +++ b/scripts/experiments/gaussian_splatting/nvdiffrast_jax_optim.ipynb @@ -21,33 +21,38 @@ "import jax.numpy as jnp\n", "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", "import matplotlib.pyplot as plt\n", + "\n", + "\n", "def projection(x=0.1, n=1.0, f=50.0):\n", - " return np.array([[n/x, 0, 0, 0],\n", - " [ 0, n/x, 0, 0],\n", - " [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],\n", - " [ 0, 0, -1, 0]]).astype(np.float32)\n", + " return np.array(\n", + " [\n", + " [n / x, 0, 0, 0],\n", + " [0, n / x, 0, 0],\n", + " [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],\n", + " [0, 0, -1, 0],\n", + " ]\n", + " ).astype(np.float32)\n", + "\n", + "\n", "def translate(x, y, z):\n", - " return np.array([[1, 0, 0, x],\n", - " [0, 1, 0, y],\n", - " [0, 0, 1, z],\n", - " [0, 0, 0, 1]]).astype(np.float32)\n", + " return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]).astype(\n", + " np.float32\n", + " )\n", + "\n", "\n", "mvp = np.matmul(projection(x=0.4), translate(0, 0, 0.0)).astype(np.float32)\n", "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=200.0, fy=200.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=5.5\n", + " height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5\n", ")\n", "jax_renderer = JaxRenderer(intrinsics)\n", "\n", "import os\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 14\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)\n", + "m = b.utils.scale_mesh(m, 1.0 / 100.0)\n", "\n", "pos = jnp.array(m.vertices.astype(np.float32))\n", "pos_idx = jnp.array(m.faces.astype(np.int32))" @@ -98,9 +103,11 @@ ], "source": [ "def xfm_points(points, matrix):\n", - " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1)\n", + " points = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)\n", " return jnp.matmul(points, matrix.T)\n", - "xfm_points(jnp.zeros((10,3)), jnp.eye(4))" + "\n", + "\n", + "xfm_points(jnp.zeros((10, 3)), jnp.eye(4))" ] }, { @@ -116,16 +123,25 @@ "# colors,_ = jax_renderer.interpolate(vertices_pose_transformed[:,2:3][None,...], rast_out, face_indices, rast_out_db, jnp.array([0]))\n", "key = jax.random.PRNGKey(15)\n", "\n", - "def render_img(translation,quat):\n", - " pose = 
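`xfm_points` above is the usual homogeneous-coordinate trick: append w = 1, then multiply by the transposed 4x4. A small self-contained check (a pure translation should shift every point and leave w = 1):

```python
import jax.numpy as jnp


def xfm_points(points, matrix):
    """Append w=1 and apply a 4x4 transform, as in the hunk above."""
    points_h = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1)
    return points_h @ matrix.T


M = jnp.eye(4).at[:3, 3].set(jnp.array([1.0, 2.0, 3.0]))
print(xfm_points(jnp.zeros((2, 3)), M))  # rows of [1. 2. 3. 1.]
```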
b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), translation)\n", + "\n", + "def render_img(translation, quat):\n", + " pose = b.transform_from_rot_and_pos(\n", + " b.quaternion_to_rotation_matrix(quat), translation\n", + " )\n", "\n", " final_mtx_proj = mvp @ pose\n", "\n", - " posw = jnp.concatenate([pos, jnp.ones((*pos.shape[:-1],1))], axis=-1)\n", + " posw = jnp.concatenate([pos, jnp.ones((*pos.shape[:-1], 1))], axis=-1)\n", " pos_clip_ja = xfm_points(pos, final_mtx_proj)\n", "\n", - " rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], pos_idx, jnp.array([intrinsics.height, intrinsics.width]))\n", - " gb_pos,_ = jax_renderer.interpolate(posw[None,...], rast_out, pos_idx, rast_out_db, jnp.array([0,1,2,3]))\n", + " rast_out, rast_out_db = jax_renderer.rasterize(\n", + " pos_clip_ja[None, ...],\n", + " pos_idx,\n", + " jnp.array([intrinsics.height, intrinsics.width]),\n", + " )\n", + " gb_pos, _ = jax_renderer.interpolate(\n", + " posw[None, ...], rast_out, pos_idx, rast_out_db, jnp.array([0, 1, 2, 3])\n", + " )\n", " mask = rast_out[..., -1] > 0\n", " shape_keep = gb_pos.shape\n", " gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])\n", @@ -134,11 +150,13 @@ " depth = depth.reshape(shape_keep)[..., 2] * -1\n", " return depth, mask\n", "\n", + "\n", "def loss(translation, quat, colors_gt, mask):\n", - " colors, _ = render_img(translation,quat)\n", + " colors, _ = render_img(translation, quat)\n", " return jnp.mean(jnp.abs(colors_gt - colors) * mask)\n", "\n", - "grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0,1)))" + "\n", + "grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))" ] }, { @@ -178,11 +196,11 @@ "source": [ "key = jax.random.split(key, 2)[0]\n", "translation, quat = jnp.array([0.0, 0.0, -3.2]), jax.random.uniform(key, shape=(4,))\n", - "colors_gt, gt_mask = render_img(translation,quat)\n", + "colors_gt, gt_mask = render_img(translation, quat)\n", "print(colors_gt.shape)\n", "print(gt_mask.shape)\n", - "fig,ax = plt.subplots(1,2)\n", - "ax[0].imshow(colors_gt[0,...])\n", + "fig, ax = plt.subplots(1, 2)\n", + "ax[0].imshow(colors_gt[0, ...])\n", "ax[1].imshow(gt_mask[0])" ] }, @@ -238,7 +256,7 @@ "translation, quat = jnp.array([0.0, 0.0, -3.4]), jax.random.uniform(key, shape=(4,))\n", "loss, (pos_grad, quat_grad) = grad_func(translation, quat, colors_gt, gt_mask)\n", "print(loss)\n", - "begining_img,_ = render_img(translation,quat)\n", + "begining_img, _ = render_img(translation, quat)\n", "pbar = tqdm(range(2000))\n", "for _ in pbar:\n", " loss, (pos_grad, quat_grad) = grad_func(translation, quat, colors_gt, gt_mask)\n", @@ -246,11 +264,11 @@ " translation = translation - 0.01 * pos_grad\n", " quat = quat - 1.0 * quat_grad\n", "\n", - "img,_ = render_img(translation,quat)\n", - "fig,ax = plt.subplots(3)\n", - "ax[0].imshow(colors_gt[0,...])\n", - "ax[1].imshow(begining_img[0,...])\n", - "ax[2].imshow(img[0,...])" + "img, _ = render_img(translation, quat)\n", + "fig, ax = plt.subplots(3)\n", + "ax[0].imshow(colors_gt[0, ...])\n", + "ax[1].imshow(begining_img[0, ...])\n", + "ax[2].imshow(img[0, ...])" ] } ], diff --git a/scripts/experiments/gaussian_splatting/nvdiffrast_optim.ipynb b/scripts/experiments/gaussian_splatting/nvdiffrast_optim.ipynb index 085aab5b..5533d6b6 100644 --- a/scripts/experiments/gaussian_splatting/nvdiffrast_optim.ipynb +++ b/scripts/experiments/gaussian_splatting/nvdiffrast_optim.ipynb @@ -30,8 +30,9 @@ "import pytorch3d.transforms\n", "import jax.numpy as jnp\n", "import nvdiffrast.torch as 
dr\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device\n" + "device" ] }, { @@ -41,22 +42,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 1e-3\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = 128\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 1e-3\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = 128\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -67,17 +68,26 @@ "outputs": [], "source": [ "def projection(x=0.1, n=1.0, f=50.0):\n", - " return np.array([[n/x, 0, 0, 0],\n", - " [ 0, n/x, 0, 0],\n", - " [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],\n", - " [ 0, 0, -1, 0]]).astype(np.float32)\n", + " return np.array(\n", + " [\n", + " [n / x, 0, 0, 0],\n", + " [0, n / x, 0, 0],\n", + " [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],\n", + " [0, 0, -1, 0],\n", + " ]\n", + " ).astype(np.float32)\n", + "\n", + "\n", "def translate(x, y, z):\n", - " return np.array([[1, 0, 0, x],\n", - " [0, 1, 0, y],\n", - " [0, 0, 1, z],\n", - " [0, 0, 0, 1]]).astype(np.float32)\n", - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()\n", - "mvp = torch.tensor(np.matmul(projection(x=0.4), translate(0, 0, 0.0)).astype(np.float32), device='cuda')" + " return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]).astype(\n", + " np.float32\n", + " )\n", + "\n", + "\n", + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()\n", + "mvp = torch.tensor(\n", + " np.matmul(projection(x=0.4), translate(0, 0, 0.0)).astype(np.float32), device=\"cuda\"\n", + ")" ] }, { @@ -98,8 +108,8 @@ " Returns:\n", " Rotation matrices as tensor of shape (..., 3, 3).\n", " \"\"\"\n", - " positions = poses[...,:3]\n", - " quaternions = poses[...,3:]\n", + " positions = poses[..., :3]\n", + " quaternions = poses[..., 3:]\n", " r, i, j, k = torch.unbind(quaternions, -1)\n", " x, y, z = torch.unbind(positions, -1)\n", " # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n", @@ -129,6 +139,7 @@ " rotation_matrices = o.reshape(quaternions.shape[:-1] + (4, 4))\n", " return rotation_matrices\n", "\n", + "\n", "# Transform vertex positions to clip space\n", "def transform_pos(mtx, pos):\n", " t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx, np.ndarray) else mtx\n", @@ -136,16 +147,21 @@ " posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)\n", " return torch.matmul(posw, t_mtx.t())[None, ...]\n", "\n", + "\n", "def render(glctx, mtx, pos, pos_idx, resolution: int):\n", " # Setup TF graph for reference.\n", " depth_ = pos[..., 2:3]\n", - " depth = torch.tensor([[[(z_val/1)] for z_val in depth_.squeeze()]], dtype=torch.float32).cuda()\n", - " pos_clip = transform_pos(mtx, pos)\n", - " rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution])\n", - " color , _ = dr.interpolate(depth, rast_out, pos_idx)\n", + " depth = torch.tensor(\n", + " [[[(z_val / 1)] for z_val in 
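The optimization loop above is plain gradient descent: `jax.value_and_grad(loss, argnums=(0, 1))` differentiates through the renderer with respect to both the translation and the (unnormalized) quaternion, and the parameters are updated by hand. The same pattern with a toy stand-in loss, to make the update rule explicit:

```python
import jax
import jax.numpy as jnp


def loss(translation, quat):
    # Stand-in for the image loss: pull both parameter blocks toward targets.
    return jnp.sum((translation - 1.0) ** 2) + jnp.sum((quat - 0.5) ** 2)


grad_func = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))

translation = jnp.zeros(3)
quat = jnp.array([1.0, 0.0, 0.0, 0.0])
for _ in range(100):
    val, (g_t, g_q) = grad_func(translation, quat)
    translation = translation - 0.1 * g_t  # same hand-rolled step as the loop above
    quat = quat - 0.1 * g_q
print(val, translation, quat)
```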
depth_.squeeze()]], dtype=torch.float32\n", + " ).cuda()\n", + " pos_clip = transform_pos(mtx, pos)\n", + " rast_out, _ = dr.rasterize(\n", + " glctx, pos_clip, pos_idx, resolution=[resolution, resolution]\n", + " )\n", + " color, _ = dr.interpolate(depth, rast_out, pos_idx)\n", " # color = dr.antialias(color, rast_out, pos_clip, pos_idx)\n", " return color\n", - " # return rast_out[:,:,:,2:3]\n" + " # return rast_out[:,:,:,2:3]" ] }, { @@ -170,23 +186,33 @@ "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", + "\n", + "\n", "def apply_transform(points, transform):\n", " rels_ = torch.cat(\n", " (\n", " points,\n", - " torch.ones((points.shape[0], 1), device=device),\n", + " torch.ones((points.shape[0], 1), device=device),\n", " ),\n", " 1,\n", " )\n", - " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[...,:3]\n", + " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[..., :3]\n", + "\n", + "\n", "position = torch.tensor([0.0, 0.1, 0.2], device=device)\n", - "quat = torch.tensor([1.0, 0.1, 0.2, 0.3],device=device)\n", - "points = torch.zeros((5,3), device = device)\n", + "quat = torch.tensor([1.0, 0.1, 0.2, 0.3], device=device)\n", + "points = torch.zeros((5, 3), device=device)\n", "print(apply_transform(points, posevec_to_matrix(position, quat)))" ] }, @@ -212,18 +238,18 @@ } ], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 14\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)\n", + "m = b.utils.scale_mesh(m, 1.0 / 100.0)\n", "\n", "# m = b.utils.make_cuboid_mesh(jnp.array([0.5, 0.5, 0.2]))\n", "\n", "vtx_pos = torch.from_numpy(m.vertices.astype(np.float32)).cuda()\n", "pos_idx = torch.from_numpy(m.faces.astype(np.int32)).cuda()\n", - "col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0],3)).astype(np.int32)).cuda()\n", - "vtx_col = torch.from_numpy(np.ones((1,3)).astype(np.float32)).cuda()\n", + "col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0], 3)).astype(np.int32)).cuda()\n", + "vtx_col = torch.from_numpy(np.ones((1, 3)).astype(np.float32)).cuda()\n", "# print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "print(pos_idx.shape, vtx_pos.shape, col_idx.shape, vtx_col.shape)\n", "print(vtx_pos, vtx_col)" @@ -236,13 +262,20 @@ "metadata": {}, "outputs": [], "source": [ - "def render_wrapper(pos,quat):\n", - " rast_target = render(glctx, torch.matmul(mvp, posevec_to_matrix(pos, quat)), vtx_pos, pos_idx, resolution)\n", + "def render_wrapper(pos, quat):\n", + " rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, posevec_to_matrix(pos, quat)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " resolution,\n", + " )\n", " return rast_target\n", "\n", + "\n", "def get_viz(rast_target):\n", - " img_target = rast_target[0].detach().cpu().numpy()\n", - " viz = 
b.get_depth_image(img_target[:,:,0]* 255.0)\n", + " img_target = rast_target[0].detach().cpu().numpy()\n", + " viz = b.get_depth_image(img_target[:, :, 0] * 255.0)\n", " return viz" ] }, @@ -273,12 +306,12 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0, -2.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, -2.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "\n", - "ground_truth_image = render_wrapper(pos,quat)\n", + "ground_truth_image = render_wrapper(pos, quat)\n", "viz_gt = get_viz(ground_truth_image)\n", - "viz_gt\n" + "viz_gt" ] }, { @@ -288,11 +321,13 @@ "metadata": {}, "outputs": [], "source": [ - "pos = torch.tensor([0.0, 0.0, -2.2],device=device, requires_grad=True)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device, requires_grad=True)\n", - "rendered_image = render_wrapper(pos,quat)\n", + "pos = torch.tensor([0.0, 0.0, -2.2], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " torch.rand(4, device=device) - 0.5, device=device, requires_grad=True\n", + ")\n", + "rendered_image = render_wrapper(pos, quat)\n", "viz = get_viz(rendered_image)\n", - "b.hstack_images([viz, viz_gt])\n" + "b.hstack_images([viz, viz_gt])" ] }, { @@ -302,15 +337,19 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.001, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 1.0, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 1.0, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "print(quat)\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " rendered_image = render_wrapper(pos, quat)\n", + " rendered_image = render_wrapper(pos, quat)\n", " loss = torch.abs(ground_truth_image - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", @@ -329,7 +368,7 @@ "metadata": {}, "outputs": [], "source": [ - "diff = (rast_opt - rast_target)**2 # L2 norm.\n", + "diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", "diff.sum()" ] }, @@ -340,33 +379,41 @@ "metadata": {}, "outputs": [], "source": [ - "optimizer = torch.optim.Adam([pose_opt], lr=0.00001)\n", + "optimizer = torch.optim.Adam([pose_opt], lr=0.00001)\n", "images = []\n", "\n", - "for _ in tqdm(range(200)): \n", - " rast_opt = render(glctx, torch.matmul(mvp, quaternion_to_matrix(pose_opt)), vtx_pos, pos_idx, resolution)\n", + "for _ in tqdm(range(200)):\n", + " rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, quaternion_to_matrix(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " resolution,\n", + " )\n", "\n", - " diff = (rast_opt - rast_target)**2 # L2 norm.\n", + " diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", - " \n", + "\n", " if (loss_val < loss_best) and (loss_val > 0.0):\n", " loss_best = loss_val\n", - " \n", + "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", - " print(loss) \n", + " print(loss)\n", " with torch.no_grad():\n", - " pose_opt /= torch.sum(pose_opt**2)**0.5\n", - " \n", - " img_opt = rast_opt[0].detach().cpu().numpy()\n", + " pose_opt /= torch.sum(pose_opt**2) ** 0.5\n", + "\n", + " img_opt = rast_opt[0].detach().cpu().numpy()\n", " images.append(\n", - " b.hstack_images([\n", - " 
b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - " ])\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + " )\n", " )" ] }, @@ -377,7 +424,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.vstack_images([images[0],images[-1]])" + "b.vstack_images([images[0], images[-1]])" ] }, { diff --git a/scripts/experiments/gaussian_splatting/splatting.ipynb b/scripts/experiments/gaussian_splatting/splatting.ipynb index 6be25d22..e80311db 100644 --- a/scripts/experiments/gaussian_splatting/splatting.ipynb +++ b/scripts/experiments/gaussian_splatting/splatting.ipynb @@ -7,7 +7,10 @@ "outputs": [], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -19,6 +22,7 @@ "from random import randint\n", "import pytorch3d.transforms\n", "import jax\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, @@ -47,18 +51,23 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=100.0, fy=100.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=1.75\n", + " height=100, width=100, fx=100.0, fy=100.0, cx=50.0, cy=50.0, near=0.01, far=1.75\n", ")\n", "fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2\n", "fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2\n", "tan_fovx = math.tan(fovX)\n", "tan_fovy = math.tan(fovY)\n", "\n", - "def render_jax(means3D_jax, opacity_jax, scales_jax, rotations_jax, camera_pose_jax, obs_image_jax, intrinsics):\n", + "\n", + "def render_jax(\n", + " means3D_jax,\n", + " opacity_jax,\n", + " scales_jax,\n", + " rotations_jax,\n", + " camera_pose_jax,\n", + " obs_image_jax,\n", + " intrinsics,\n", + "):\n", " N = means3D_jax.shape[0]\n", "\n", " means3D = torch.tensor(b.utils.jax_to_torch(means3D_jax), requires_grad=True)\n", @@ -66,10 +75,12 @@ " rotations = torch.tensor(b.utils.jax_to_torch(rotations_jax), requires_grad=True)\n", " opacity = torch.tensor(b.utils.jax_to_torch(opacity_jax), requires_grad=True)\n", "\n", - " means2D = torch.tensor(torch.rand((N, 3)),requires_grad=True,device=device)\n", - " \n", + " means2D = torch.tensor(torch.rand((N, 3)), requires_grad=True, device=device)\n", + "\n", " proj_matrix = b.utils.jax_to_torch(b.camera.getProjectionMatrix(intrinsics))\n", - " view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose_jax))),0,1).cuda()\n", + " view_matrix = torch.transpose(\n", + " torch.tensor(np.array(b.inverse_pose(camera_pose_jax))), 0, 1\n", + " ).cuda()\n", " raster_settings = GaussianRasterizationSettings(\n", " image_height=int(intrinsics.height),\n", " image_width=int(intrinsics.width),\n", @@ -82,26 +93,29 @@ " sh_degree=0,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", " )\n", " rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n", "\n", " gt_rendered_image, radii = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = means3D[:,2:3].repeat(1,3),\n", - " opacities = torch.sigmoid(opacity),\n", - " scales = torch.exp(scales),\n", - " rotations = 
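Worth highlighting in the hunk above: the Adam parameter groups give position and orientation very different step sizes (0.001 vs 1.0), since a unit step on a quaternion is a large rotation while a unit step in position is a large translation. A minimal CPU sketch of the pattern, with a toy quadratic loss in place of the renderer:

```python
import torch

pos = torch.zeros(3, requires_grad=True)
quat = torch.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)

optimizer = torch.optim.Adam(
    [
        {"params": [pos], "lr": 0.001},
        {"params": [quat], "lr": 1.0},
    ],
    eps=1e-15,
)

loss = (pos - 1.0).pow(2).sum() + (quat - 0.5).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Adam's first step is ~lr in magnitude, so quat moves ~1000x farther than pos.
print(pos, quat)
```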
rotations\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=means3D[:, 2:3].repeat(1, 3),\n", + " opacities=torch.sigmoid(opacity),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", " )\n", " obs_image = b.utils.jax_to_torch(obs_image_jax)\n", " loss = torch.abs(gt_rendered_image - obs_image).mean()\n", " loss.backward()\n", "\n", - "\n", - " return b.utils.torch_to_jax(gt_rendered_image[2,...]), b.utils.torch_to_jax(means3D.grad), b.utils.torch_to_jax(opacity.grad), loss\n", - "\n" + " return (\n", + " b.utils.torch_to_jax(gt_rendered_image[2, ...]),\n", + " b.utils.torch_to_jax(means3D.grad),\n", + " b.utils.torch_to_jax(opacity.grad),\n", + " loss,\n", + " )" ] }, { @@ -126,9 +140,9 @@ ], "source": [ "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(17).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)" + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(17).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)" ] }, { @@ -150,12 +164,22 @@ ], "source": [ "object_pose = b.transform_from_pos(jnp.array([0.0, 0.0, 0.3]))\n", - "camera_poses = [jnp.eye(4), b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi/40), b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi/20)]\n", + "camera_poses = [\n", + " jnp.eye(4),\n", + " b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi / 40),\n", + " b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi / 20),\n", + "]\n", "\n", - "point_cloud_images = [b.RENDERER.render(b.inverse_pose(cp) @ object_pose[None,...], jnp.array([0])) for cp in camera_poses]\n", - "gt_images = [torch.tensor(np.array(gt_img[...,2]),device=device).detach() for gt_img in point_cloud_images]\n", + "point_cloud_images = [\n", + " b.RENDERER.render(b.inverse_pose(cp) @ object_pose[None, ...], jnp.array([0]))\n", + " for cp in camera_poses\n", + "]\n", + "gt_images = [\n", + " torch.tensor(np.array(gt_img[..., 2]), device=device).detach()\n", + " for gt_img in point_cloud_images\n", + "]\n", "gt_images_stacked = torch.stack(gt_images)\n", - "b.hstack_images([b.get_depth_image(img[:,:,2]) for img in point_cloud_images])" + "b.hstack_images([b.get_depth_image(img[:, :, 2]) for img in point_cloud_images])" ] }, { @@ -192,17 +216,35 @@ } ], "source": [ - "point_cloud_image = point_cloud_images[0][...,:3]\n", - "point_cloud = point_cloud_image.reshape(-1,3)\n", - "point_cloud_not_far = point_cloud[point_cloud[:,2] < intrinsics.far, :]\n", - "means3D = jnp.tile(point_cloud_not_far, (5,1))\n", + "point_cloud_image = point_cloud_images[0][..., :3]\n", + "point_cloud = point_cloud_image.reshape(-1, 3)\n", + "point_cloud_not_far = point_cloud[point_cloud[:, 2] < intrinsics.far, :]\n", + "means3D = jnp.tile(point_cloud_not_far, (5, 1))\n", "N = means3D.shape[0]\n", - "opacity, scales, rotations = jnp.ones((N,1)), jnp.ones((N,3)) - 20.0, jnp.ones((N,4)) \n", + "opacity, scales, rotations = jnp.ones((N, 1)), jnp.ones((N, 3)) - 20.0, jnp.ones((N, 4))\n", "\n", - "imgs = [render_jax(means3D, opacity, scales, rotations, camera_pose, point_cloud_images[0][...,2],intrinsics)[0] for camera_pose in camera_poses ]\n", + "imgs = [\n", + " render_jax(\n", + " means3D,\n", + " opacity,\n", + " 
scales,\n", + " rotations,\n", + " camera_pose,\n", + " point_cloud_images[0][..., 2],\n", + " intrinsics,\n", + " )[0]\n", + " for camera_pose in camera_poses\n", + "]\n", "b.clear()\n", - "b.show_cloud(\"gt\",b.unproject_depth_jit(point_cloud_images[0][...,2], intrinsics).reshape(-1,3)) \n", - "b.show_cloud(\"reconstruction\",b.unproject_depth_jit(imgs[0], intrinsics).reshape(-1,3),color=b.BLUE) \n", + "b.show_cloud(\n", + " \"gt\",\n", + " b.unproject_depth_jit(point_cloud_images[0][..., 2], intrinsics).reshape(-1, 3),\n", + ")\n", + "b.show_cloud(\n", + " \"reconstruction\",\n", + " b.unproject_depth_jit(imgs[0], intrinsics).reshape(-1, 3),\n", + " color=b.BLUE,\n", + ")\n", "b.hstack_images([b.get_depth_image(d) for d in imgs])" ] }, @@ -336,14 +378,20 @@ ], "source": [ "camera_pose = camera_poses[0]\n", - "gt_img = point_cloud_images[1][...,2]\n", + "gt_img = point_cloud_images[1][..., 2]\n", "for _ in range(100):\n", - " img, grad_means, grad_opacity, loss = render_jax(means3D,opacity, scales, rotations, camera_pose, gt_img,intrinsics)\n", + " img, grad_means, grad_opacity, loss = render_jax(\n", + " means3D, opacity, scales, rotations, camera_pose, gt_img, intrinsics\n", + " )\n", " print(loss)\n", " means3D = means3D - 0.1 * grad_means\n", " opacity = opacity - 0.1 * grad_opacity\n", - "b.show_cloud(\"gt\",b.unproject_depth_jit(gt_img, intrinsics).reshape(-1,3)) \n", - "b.show_cloud(\"reconstruction\",b.unproject_depth_jit(img, intrinsics).reshape(-1,3),color=b.BLUE) " + "b.show_cloud(\"gt\", b.unproject_depth_jit(gt_img, intrinsics).reshape(-1, 3))\n", + "b.show_cloud(\n", + " \"reconstruction\",\n", + " b.unproject_depth_jit(img, intrinsics).reshape(-1, 3),\n", + " color=b.BLUE,\n", + ")" ] }, { @@ -380,22 +428,26 @@ "outputs": [], "source": [ "l = [\n", - " {'params': [means3D], 'lr': 0.01 ,\"name\": \"xyz\"},\n", - " {'params': [opacity], 'lr': 0.05, \"name\": \"opacity\"},\n", - " {'params': [scales], 'lr': 0.05, \"name\": \"scaling\"},\n", - " {'params': [rotations], 'lr': 0.01, \"name\": \"rotation\"}\n", + " {\"params\": [means3D], \"lr\": 0.01, \"name\": \"xyz\"},\n", + " {\"params\": [opacity], \"lr\": 0.05, \"name\": \"opacity\"},\n", + " {\"params\": [scales], \"lr\": 0.05, \"name\": \"scaling\"},\n", + " {\"params\": [rotations], \"lr\": 0.01, \"name\": \"rotation\"},\n", "]\n", "optimizer = torch.optim.SGD(l, lr=0.0)\n", "\n", "pbar = tqdm(range(1000))\n", "for _ in pbar:\n", - " imgs = torch.stack([render(means3D, means2D, opacity, scales, rotations, camera_pose) for camera_pose in camera_poses])\n", + " imgs = torch.stack(\n", + " [\n", + " render(means3D, means2D, opacity, scales, rotations, camera_pose)\n", + " for camera_pose in camera_poses\n", + " ]\n", + " )\n", " loss = torch.abs(gt_images_stacked - imgs).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " pbar.set_description(f\"{loss.item()}\")\n", - "\n" + " pbar.set_description(f\"{loss.item()}\")" ] }, { @@ -405,9 +457,18 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"gt\",b.unproject_depth_jit(gt_images_stacked[0].detach().cpu().numpy(), intrinsics).reshape(-1,3)) \n", + "b.show_cloud(\n", + " \"gt\",\n", + " b.unproject_depth_jit(\n", + " gt_images_stacked[0].detach().cpu().numpy(), intrinsics\n", + " ).reshape(-1, 3),\n", + ")\n", "# b.show_cloud(\"means\", means3D.detach().cpu().numpy(),color=b.RED)\n", - "b.show_cloud(\"reconstruction\",b.unproject_depth_jit(convert_to_numpy(imgs[0]), intrinsics).reshape(-1,3),color=b.BLUE) \n", + 
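Unlike the torch loops elsewhere in this diff, the splatting loop above never builds a torch optimizer: the rasterizer computes gradients on the torch side, they cross back to JAX as plain arrays, and the update is hand-rolled. A dependency-light sketch of that round trip, with a toy quadratic in place of `render_jax` (`loss_and_grads` is an illustrative name):

```python
import numpy as np
import torch


def loss_and_grads(params_np):
    """Torch computes the gradient; the caller stays in NumPy/JAX land."""
    params = torch.tensor(params_np, requires_grad=True)
    loss = ((params - 2.0) ** 2).sum()
    loss.backward()
    return float(loss), params.grad.numpy()


params = np.zeros(3, dtype=np.float32)
for _ in range(50):
    loss, grad = loss_and_grads(params)
    params = params - 0.1 * grad  # same hand-rolled update as the loop above
print(loss, params)  # loss -> 0, params -> [2. 2. 2.]
```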
"b.show_cloud(\n", + " \"reconstruction\",\n", + " b.unproject_depth_jit(convert_to_numpy(imgs[0]), intrinsics).reshape(-1, 3),\n", + " color=b.BLUE,\n", + ")\n", "\n", "b.get_depth_image(convert_to_numpy(imgs[0]))" ] @@ -419,19 +480,23 @@ "outputs": [], "source": [ "img, radii = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = means3D[:,2:3].repeat(1,3),\n", - " opacities = torch.tensor(torch.ones((N, 1)),requires_grad=True,device=device),\n", - " scales = torch.exp(scales),\n", - " rotations = rotations\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=means3D[:, 2:3].repeat(1, 3),\n", + " opacities=torch.tensor(torch.ones((N, 1)), requires_grad=True, device=device),\n", + " scales=torch.exp(scales),\n", + " rotations=rotations,\n", ")\n", - "depth_image = np.moveaxis(img.detach().cpu().numpy(),0,-1)[...,2]\n", + "depth_image = np.moveaxis(img.detach().cpu().numpy(), 0, -1)[..., 2]\n", "b.clear()\n", - "# b.show_cloud(\"gt\",b.unproject_depth_jit(gt_img, intrinsics).reshape(-1,3)) \n", + "# b.show_cloud(\"gt\",b.unproject_depth_jit(gt_img, intrinsics).reshape(-1,3))\n", "# b.show_cloud(\"means\", means3D.detach().cpu().numpy(),color=b.RED)\n", - "b.show_cloud(\"reconstruction\",b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3),color=b.BLUE) \n", + "b.show_cloud(\n", + " \"reconstruction\",\n", + " b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3),\n", + " color=b.BLUE,\n", + ")\n", "\n", "b.get_depth_image(depth_image)" ] @@ -452,14 +517,14 @@ "metadata": {}, "outputs": [], "source": [ - "pos = torch.tensor([0.0, 0.0, 0.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, 0.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "\n", - "gt_rendered_image = render(pos, quat).detach()\n", - "depth_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "gt_rendered_image = render(pos, quat).detach()\n", + "depth_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz_gt = b.get_depth_image(depth_image)\n", - "viz_gt\n" + "viz_gt" ] }, { @@ -468,11 +533,13 @@ "metadata": {}, "outputs": [], "source": [ - "pos = torch.tensor([0.0, 0.0, 0.5],device=device, requires_grad=True)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device, requires_grad=True)\n", - "rendered_image = render(pos, quat)\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "pos = torch.tensor([0.0, 0.0, 0.5], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " torch.rand(4, device=device) - 0.5, device=device, requires_grad=True\n", + ")\n", + "rendered_image = render(pos, quat)\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = b.get_depth_image(depth_image)\n", "b.hstack_images([viz, viz_gt])" ] @@ -483,20 +550,23 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.001, \"name\": 
\"pos\"},\n", - " {'params': [quat], 'lr': 0.001, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.001, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.001, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "pbar = tqdm(range(1000))\n", "for _ in pbar:\n", - " rendered_image = render(pos, quat)\n", + " rendered_image = render(pos, quat)\n", " loss = torch.abs(gt_rendered_image - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " pbar.set_description(f\"{loss.item()}\")\n" + " pbar.set_description(f\"{loss.item()}\")" ] }, { @@ -505,11 +575,10 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = b.get_depth_image(depth_image)\n", - "b.hstack_images([viz, viz_gt])\n" + "b.hstack_images([viz, viz_gt])" ] }, { diff --git a/scripts/experiments/gaussian_splatting/splatting_messing_with_it.ipynb b/scripts/experiments/gaussian_splatting/splatting_messing_with_it.ipynb index e08a8cba..a65d2824 100644 --- a/scripts/experiments/gaussian_splatting/splatting_messing_with_it.ipynb +++ b/scripts/experiments/gaussian_splatting/splatting_messing_with_it.ipynb @@ -18,7 +18,10 @@ ], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -31,8 +34,9 @@ "import pytorch3d.transforms\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device\n" + "device" ] }, { @@ -75,11 +79,7 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=250.0, fy=250.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=2.5\n", + " height=100, width=100, fx=250.0, fy=250.0, cx=50.0, cy=50.0, near=0.01, far=2.5\n", ")\n", "b.setup_renderer(intrinsics)" ] @@ -90,12 +90,12 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(14).rjust(6, '0') + \".ply\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(14).rjust(6, \"0\") + \".ply\")\n", "mesh = b.utils.load_mesh(mesh_path)\n", - "mesh = b.utils.scale_mesh(mesh, 1.0/1000.0)\n", + "mesh = b.utils.scale_mesh(mesh, 1.0 / 1000.0)\n", "b.RENDERER.add_mesh(mesh)\n", - "vertices = torch.tensor(np.array(jnp.array(mesh.vertices)),device=device)" + "vertices = torch.tensor(np.array(jnp.array(mesh.vertices)), device=device)" ] }, { @@ -168,36 +168,53 @@ " P[1, 1] = 2.0 * intrinsics.near / (top - bottom)\n", " P[0, 2] = (right + left) / (right - left)\n", " P[1, 2] = (top + bottom) / (top - bottom)\n", - " P[2, 2] = z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", - " P[2, 3] = 
-2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " P[2, 2] = (\n", + " z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", + " P[2, 3] = (\n", + " -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", " P[3, 2] = z_sign\n", " return torch.transpose(P, 0, 1)\n", "\n", + "\n", "proj_matrix = torch.tensor(getProjectionMatrix(intrinsics), device=device)\n", "\n", + "\n", "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", + "\n", + "\n", "def apply_transform(points, transform):\n", " rels_ = torch.cat(\n", " (\n", " points,\n", - " torch.ones((points.shape[0], 1), device=device),\n", + " torch.ones((points.shape[0], 1), device=device),\n", " ),\n", " 1,\n", " )\n", - " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[...,:3]\n", + " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[..., :3]\n", + "\n", "\n", "position = torch.tensor([0.0, 0.1, 0.2], device=device)\n", - "quat = torch.tensor([1.0, 0.1, 0.2, 0.3],device=device)\n", + "quat = torch.tensor([1.0, 0.1, 0.2, 0.3], device=device)\n", "\n", "camera_pose = jnp.eye(4)\n", - "view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose))),0,1).cuda()\n", + "view_matrix = torch.transpose(\n", + " torch.tensor(np.array(b.inverse_pose(camera_pose))), 0, 1\n", + ").cuda()\n", "raster_settings = GaussianRasterizationSettings(\n", " image_height=int(intrinsics.height),\n", " image_width=int(intrinsics.width),\n", @@ -210,7 +227,7 @@ " sh_degree=1,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", ")\n", "rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n", "\n", @@ -218,28 +235,32 @@ "def render(pose, in_color):\n", " means3D = apply_transform(vertices, pose)\n", " N = means3D.shape[0]\n", - " means2D = torch.ones((N, 3),requires_grad=True, device=device)\n", - " opacity = torch.rand((N, 1),requires_grad=True,device=device)\n", - " scales = torch.tensor( 0.005 * torch.rand((N, 3)),requires_grad=True,device=device)\n", - " rotations = torch.rand((N, 4),requires_grad=True,device=device)\n", + " means2D = torch.ones((N, 3), requires_grad=True, device=device)\n", + " opacity = torch.rand((N, 1), requires_grad=True, device=device)\n", + " scales = torch.tensor(0.005 * torch.rand((N, 3)), requires_grad=True, device=device)\n", + " rotations = torch.rand((N, 4), requires_grad=True, device=device)\n", "\n", " data = rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = means3D[:,2:3],\n", - " opacities = opacity,\n", - " scales = scales,\n", - " rotations = rotations,\n", - " in_color=in_color\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=means3D[:, 2:3],\n", + " opacities=opacity,\n", + " scales=scales,\n", + " rotations=rotations,\n", + " in_color=in_color,\n", " )\n", " return data\n", "\n", - "gt_pos = torch.tensor([0.0, 0.0, 0.5],device=device)\n", - "gt_quat = torch.tensor(torch.rand(4,device=device) - 
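`getProjectionMatrix` above builds an OpenGL-style frustum from the pinhole intrinsics and returns its transpose (the rasterizer settings take transposed matrices). A NumPy sketch of the untransposed construction, assuming the symmetric-frustum case and z_sign = +1 (`projection_from_intrinsics` is an illustrative name):

```python
import numpy as np


def projection_from_intrinsics(fx, fy, width, height, near, far):
    # Half-extents of the image plane at the near clip distance.
    right = near * (width / 2.0) / fx
    top = near * (height / 2.0) / fy
    P = np.zeros((4, 4), dtype=np.float32)
    P[0, 0] = near / right                    # = 2n / (r - l) with l = -r
    P[1, 1] = near / top
    P[2, 2] = (far + near) / (far - near)
    P[2, 3] = -2.0 * far * near / (far - near)
    P[3, 2] = 1.0
    return P


print(projection_from_intrinsics(250.0, 250.0, 100, 100, 0.01, 2.5))
```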
0.5,device=device)\n", + "\n", + "gt_pos = torch.tensor([0.0, 0.0, 0.5], device=device)\n", + "gt_quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "gt_pose = posevec_to_matrix(gt_pos, gt_quat)\n", - "gt_color, likelihood, radii = render(gt_pose, torch.zeros((1, int(intrinsics.height), int(intrinsics.width)),device=device))\n", - "viz_gt = b.get_depth_image(gt_color.detach().cpu().numpy()[0,...])\n", + "gt_color, likelihood, radii = render(\n", + " gt_pose,\n", + " torch.zeros((1, int(intrinsics.height), int(intrinsics.width)), device=device),\n", + ")\n", + "viz_gt = b.get_depth_image(gt_color.detach().cpu().numpy()[0, ...])\n", "viz_gt" ] }, @@ -268,9 +289,9 @@ } ], "source": [ - "gt_pose_jnp =jnp.array(gt_pose.detach().cpu().numpy())\n", - "img = b.RENDERER.render(gt_pose_jnp[None,...], jnp.array([0]))\n", - "gt_color = torch.tensor(np.array(img[:,:,2]), device=device)\n", + "gt_pose_jnp = jnp.array(gt_pose.detach().cpu().numpy())\n", + "img = b.RENDERER.render(gt_pose_jnp[None, ...], jnp.array([0]))\n", + "gt_color = torch.tensor(np.array(img[:, :, 2]), device=device)\n", "\n", "viz_gt = b.get_depth_image(gt_color.detach().cpu().numpy())\n", "viz_gt" @@ -283,9 +304,11 @@ "outputs": [], "source": [ "b.clear()\n", - "gt_color_jnp = jnp.array(gt_color.detach().cpu().numpy()[0,...])\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(gt_color_jnp,intrinsics).reshape(-1,3))\n", - "b.show_cloud(\"2\", b.unproject_depth_jit(img[...,2],intrinsics).reshape(-1,3), color=b.RED)" + "gt_color_jnp = jnp.array(gt_color.detach().cpu().numpy()[0, ...])\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(gt_color_jnp, intrinsics).reshape(-1, 3))\n", + "b.show_cloud(\n", + " \"2\", b.unproject_depth_jit(img[..., 2], intrinsics).reshape(-1, 3), color=b.RED\n", + ")" ] }, { @@ -294,7 +317,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_trimesh(\"1\",mesh)\n", + "b.show_trimesh(\"1\", mesh)\n", "b.set_pose(\"1\", gt_pose_jnp)" ] }, @@ -517,24 +540,28 @@ "# pos = torch.tensor([0.0, 0.0, 1.9],device=device, requires_grad=True)\n", "# quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device, requires_grad=True)\n", "\n", - "pos = torch.tensor(gt_pos ,device=device, requires_grad=True)\n", - "quat = torch.tensor(gt_quat ,device=device, requires_grad=True)\n", + "pos = torch.tensor(gt_pos, device=device, requires_grad=True)\n", + "quat = torch.tensor(gt_quat, device=device, requires_grad=True)\n", "\n", "pose = posevec_to_matrix(pos, quat)\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.01, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.01, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "b.clear()\n", "b.show_pose(\"1\", gt_pose.detach().cpu().numpy())\n", "b.show_pose(\"2\", pose.detach().cpu().numpy())\n", "\n", - "gt_color=gt_color.detach()\n", + "gt_color = gt_color.detach()\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", " # print(pos,quat)\n", " pose = posevec_to_matrix(pos, quat)\n", - " rendered_image, likelihood, radii = render(pose, gt_color)\n", + " rendered_image, likelihood, radii = render(pose, gt_color)\n", " loss = torch.abs(gt_color - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", @@ -547,7 +574,7 @@ "b.clear()\n", "b.show_pose(\"1\", 
gt_pose.detach().cpu().numpy())\n", "b.show_pose(\"2\", pose.detach().cpu().numpy())\n", - "viz = b.get_depth_image(rendered_image.detach().cpu().numpy()[0,...])\n", + "viz = b.get_depth_image(rendered_image.detach().cpu().numpy()[0, ...])\n", "b.hstack_images([viz, viz_gt])" ] }, @@ -558,8 +585,17 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(rendered_image[0,...].detach().cpu().numpy(),intrinsics).reshape(-1,3))\n", - "b.show_cloud(\"2\", b.unproject_depth_jit(gt_color.detach().cpu().numpy(),intrinsics).reshape(-1,3), color=b.RED)" + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth_jit(\n", + " rendered_image[0, ...].detach().cpu().numpy(), intrinsics\n", + " ).reshape(-1, 3),\n", + ")\n", + "b.show_cloud(\n", + " \"2\",\n", + " b.unproject_depth_jit(gt_color.detach().cpu().numpy(), intrinsics).reshape(-1, 3),\n", + " color=b.RED,\n", + ")" ] }, { @@ -659,31 +695,31 @@ "outputs": [], "source": [ "T = 0\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(2,2,1)\n", + "fig = plt.figure(figsize=(6, 6))\n", + "ax = fig.add_subplot(2, 2, 1)\n", "ax.set_title(\"Target\")\n", - "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img1 = ax.imshow(b.preprocess_for_viz(gt_image),cmap=b.cmap)\n", - "ax = fig.add_subplot(2,2,2)\n", + "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img1 = ax.imshow(b.preprocess_for_viz(gt_image), cmap=b.cmap)\n", + "ax = fig.add_subplot(2, 2, 2)\n", "parameters = parameters_over_time[T]\n", "rendered_image = render(*parameters)\n", - "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img2 = ax.imshow(b.preprocess_for_viz(rendered_image),cmap=b.cmap)\n", + "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img2 = ax.imshow(b.preprocess_for_viz(rendered_image), cmap=b.cmap)\n", "title = ax.set_title(f\"Reconstruction\")\n", - "ax = fig.add_subplot(2,1,2)\n", + "ax = fig.add_subplot(2, 1, 2)\n", "line = ax.plot(jnp.arange(T), losses_over_time[:T])\n", "# ax.set_yscale(\"log\")\n", "ax.set_title(\"Pixelwise MSE Loss\")\n", "# ax.set_ylim(0.01, 1000.0)\n", "ax.set_xlabel(\"Iteration\")\n", - "ax.set_xlim(0,len(losses_over_time))\n", + "ax.set_xlim(0, len(losses_over_time))\n", "fig.tight_layout()\n", "\n", "buffs = []\n", - "for T in tqdm(range(0,len(losses_over_time),3)):\n", + "for T in tqdm(range(0, len(losses_over_time), 3)):\n", " parameters = parameters_over_time[T]\n", " rendered_image = render(*parameters)\n", - " rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", + " rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", " img2.set_array(b.preprocess_for_viz(rendered_image))\n", " line[0].set_xdata(jnp.arange(T))\n", " line[0].set_ydata(losses_over_time[:T])\n", diff --git a/scripts/experiments/gaussian_splatting/splatting_optim.ipynb b/scripts/experiments/gaussian_splatting/splatting_optim.ipynb index 57d5b8da..01fbf744 100644 --- a/scripts/experiments/gaussian_splatting/splatting_optim.ipynb +++ b/scripts/experiments/gaussian_splatting/splatting_optim.ipynb @@ -18,7 +18,10 @@ ], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import 
os\n", "import numpy as np\n", @@ -29,6 +32,7 @@ "import jax.numpy as jnp\n", "from random import randint\n", "import pytorch3d.transforms\n", + "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] @@ -57,10 +61,10 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(14).rjust(6, '0') + \".ply\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(14).rjust(6, \"0\") + \".ply\")\n", "mesh = b.utils.load_mesh(mesh_path)\n", - "vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0),device=device)" + "vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0), device=device)" ] }, { @@ -79,11 +83,7 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=300.0, fy=300.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.01, far=2.5\n", + " height=200, width=200, fx=300.0, fy=300.0, cx=100.0, cy=100.0, near=0.01, far=2.5\n", ")\n", "fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2\n", "fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2\n", @@ -105,11 +105,16 @@ " P[1, 1] = 2.0 * intrinsics.near / (top - bottom)\n", " P[0, 2] = (right + left) / (right - left)\n", " P[1, 2] = (top + bottom) / (top - bottom)\n", - " P[2, 2] = z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", - " P[2, 3] = -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " P[2, 2] = (\n", + " z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", + " P[2, 3] = (\n", + " -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)\n", + " )\n", " P[3, 2] = z_sign\n", " return torch.transpose(P, 0, 1)\n", "\n", + "\n", "proj_matrix = torch.tensor(getProjectionMatrix(intrinsics), device=device)" ] }, @@ -134,23 +139,33 @@ "def posevec_to_matrix(position, quat):\n", " return torch.cat(\n", " (\n", - " torch.cat((pytorch3d.transforms.quaternion_to_matrix(quat), position.unsqueeze(1)), 1),\n", - " torch.tensor([[0.0, 0.0, 0.0, 1.0]],device=device),\n", + " torch.cat(\n", + " (\n", + " pytorch3d.transforms.quaternion_to_matrix(quat),\n", + " position.unsqueeze(1),\n", + " ),\n", + " 1,\n", + " ),\n", + " torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device),\n", " ),\n", " 0,\n", " )\n", + "\n", + "\n", "def apply_transform(points, transform):\n", " rels_ = torch.cat(\n", " (\n", " points,\n", - " torch.ones((points.shape[0], 1), device=device),\n", + " torch.ones((points.shape[0], 1), device=device),\n", " ),\n", " 1,\n", " )\n", - " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[...,:3]\n", + " return torch.einsum(\"ij, aj -> ai\", transform, rels_)[..., :3]\n", + "\n", + "\n", "position = torch.tensor([0.0, 0.1, 0.2], device=device)\n", - "quat = torch.tensor([1.0, 0.1, 0.2, 0.3],device=device)\n", - "points = torch.zeros((5,3), device = device)\n", + "quat = torch.tensor([1.0, 0.1, 0.2, 0.3], device=device)\n", + "points = torch.zeros((5, 3), device=device)\n", "print(apply_transform(points, posevec_to_matrix(position, quat)))" ] }, @@ -161,7 +176,9 @@ "outputs": [], "source": [ "camera_pose = jnp.eye(4)\n", - "view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose))),0,1).cuda()\n", + "view_matrix = torch.transpose(\n", + " 
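
`getProjectionMatrix` above converts pinhole intrinsics into an OpenGL-style clip matrix; the depth entries alone decide where the near and far planes land in NDC. A quick standalone check of just that sub-problem (near/far values are illustrative):

```python
# Check of the depth rows of the OpenGL-style projection built in
# getProjectionMatrix: with z_sign = 1, camera-space z = near should map to
# NDC z = -1 and z = far to NDC z = +1 after the perspective divide.
near, far = 0.01, 2.5

def ndc_depth(z_cam):
    p22 = (far + near) / (far - near)        # P[2, 2]
    p23 = -2.0 * far * near / (far - near)   # P[2, 3]
    clip_z, clip_w = p22 * z_cam + p23, z_cam  # P[3, 2] = 1 makes w = z
    return clip_z / clip_w

print(ndc_depth(near), ndc_depth(far))  # -1.0, 1.0
```
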
torch.tensor(np.array(b.inverse_pose(camera_pose))), 0, 1\n", + ").cuda()\n", "raster_settings = GaussianRasterizationSettings(\n", " image_height=int(intrinsics.height),\n", " image_width=int(intrinsics.width),\n", @@ -174,9 +191,9 @@ " sh_degree=1,\n", " campos=torch.zeros(3).cuda(),\n", " prefiltered=False,\n", - " debug=None\n", + " debug=None,\n", ")\n", - "rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n" + "rasterizer = GaussianRasterizer(raster_settings=raster_settings)" ] }, { @@ -192,22 +209,24 @@ "metadata": {}, "outputs": [], "source": [ - "def render(pos,quat):\n", + "def render(pos, quat):\n", " means3D = apply_transform(vertices, posevec_to_matrix(pos, quat))\n", " N = means3D.shape[0]\n", - " means2D = torch.ones((N, 3),requires_grad=True, device=device)\n", - " opacity = torch.ones((N, 1),requires_grad=True,device=device)\n", - " scales = torch.tensor( 0.0025 * torch.rand((N, 3)),requires_grad=True,device=device)\n", - " rotations = torch.rand((N, 4),requires_grad=True,device=device)\n", + " means2D = torch.ones((N, 3), requires_grad=True, device=device)\n", + " opacity = torch.ones((N, 1), requires_grad=True, device=device)\n", + " scales = torch.tensor(\n", + " 0.0025 * torch.rand((N, 3)), requires_grad=True, device=device\n", + " )\n", + " rotations = torch.rand((N, 4), requires_grad=True, device=device)\n", "\n", - " color, radii= rasterizer(\n", - " means3D = means3D,\n", - " means2D = means2D,\n", - " shs = None,\n", - " colors_precomp = means3D[:,2:3].repeat(1,3),\n", - " opacities = opacity,\n", - " scales = scales,\n", - " rotations = rotations\n", + " color, radii = rasterizer(\n", + " means3D=means3D,\n", + " means2D=means2D,\n", + " shs=None,\n", + " colors_precomp=means3D[:, 2:3].repeat(1, 3),\n", + " opacities=opacity,\n", + " scales=scales,\n", + " rotations=rotations,\n", " )\n", " return color" ] @@ -242,24 +261,26 @@ } ], "source": [ - "pos = torch.tensor([0.0, 0.0, 0.5],device=device)\n", - "quat = torch.tensor(torch.rand(4,device=device) - 0.5,device=device)\n", + "pos = torch.tensor([0.0, 0.0, 0.5], device=device)\n", + "quat = torch.tensor(torch.rand(4, device=device) - 0.5, device=device)\n", "\n", - "gt_rendered_image = render(pos, quat).detach()\n", - "depth_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "gt_rendered_image = render(pos, quat).detach()\n", + "depth_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz_gt = b.get_depth_image(depth_image)\n", "viz_gt\n", "\n", - "pos = torch.tensor([0.0, 0.0, 0.5],device=device, requires_grad=True)\n", - "quat = torch.tensor(quat + torch.rand(4,device=device)*0.4,device=device, requires_grad=True)\n", - "rendered_image = render(pos, quat)\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "pos = torch.tensor([0.0, 0.0, 0.5], device=device, requires_grad=True)\n", + "quat = torch.tensor(\n", + " quat + torch.rand(4, device=device) * 0.4, device=device, requires_grad=True\n", + ")\n", + "rendered_image = render(pos, quat)\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = 
b.get_depth_image(depth_image)\n", "parameters_over_time = []\n", "losses_over_time = []\n", - "b.hstack_images([viz, viz_gt])\n" + "b.hstack_images([viz, viz_gt])" ] }, { @@ -268,7 +289,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))" + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))" ] }, { @@ -337,29 +358,31 @@ } ], "source": [ - "\n", - "optimizer = torch.optim.Adam([\n", - " {'params': [pos], 'lr': 0.01, \"name\": \"pos\"},\n", - " {'params': [quat], 'lr': 0.1, \"name\": \"quat\"},\n", - "], lr=0.0, eps=1e-15)\n", + "optimizer = torch.optim.Adam(\n", + " [\n", + " {\"params\": [pos], \"lr\": 0.01, \"name\": \"pos\"},\n", + " {\"params\": [quat], \"lr\": 0.1, \"name\": \"quat\"},\n", + " ],\n", + " lr=0.0,\n", + " eps=1e-15,\n", + ")\n", "\n", "\n", "pbar = tqdm(range(100))\n", "for _ in pbar:\n", - " rendered_image = render(pos, quat)\n", + " rendered_image = render(pos, quat)\n", " loss = torch.abs(gt_rendered_image - rendered_image).mean()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " parameters_over_time.append((pos.detach().clone(),quat.detach().clone()))\n", + " parameters_over_time.append((pos.detach().clone(), quat.detach().clone()))\n", " losses_over_time.append(loss.item())\n", " pbar.set_description(f\"{loss.item()}\")\n", "\n", - "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1,3))\n", + "depth_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "b.show_cloud(\"1\", b.unproject_depth_jit(depth_image, intrinsics).reshape(-1, 3))\n", "viz = b.get_depth_image(depth_image)\n", - "b.hstack_images([viz, viz_gt])\n", - "\n" + "b.hstack_images([viz, viz_gt])" ] }, { @@ -397,31 +420,31 @@ ], "source": [ "T = 0\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(2,2,1)\n", + "fig = plt.figure(figsize=(6, 6))\n", + "ax = fig.add_subplot(2, 2, 1)\n", "ax.set_title(\"Target\")\n", - "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img1 = ax.imshow(b.preprocess_for_viz(gt_image),cmap=b.cmap)\n", - "ax = fig.add_subplot(2,2,2)\n", + "gt_image = np.moveaxis(gt_rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img1 = ax.imshow(b.preprocess_for_viz(gt_image), cmap=b.cmap)\n", + "ax = fig.add_subplot(2, 2, 2)\n", "parameters = parameters_over_time[T]\n", "rendered_image = render(*parameters)\n", - "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", - "img2 = ax.imshow(b.preprocess_for_viz(rendered_image),cmap=b.cmap)\n", + "rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", + "img2 = ax.imshow(b.preprocess_for_viz(rendered_image), cmap=b.cmap)\n", "title = ax.set_title(f\"Reconstruction\")\n", - "ax = fig.add_subplot(2,1,2)\n", + "ax = fig.add_subplot(2, 1, 2)\n", "line = ax.plot(jnp.arange(T), losses_over_time[:T])\n", "# ax.set_yscale(\"log\")\n", "ax.set_title(\"Pixelwise MSE Loss\")\n", "ax.set_ylim(-0.0001, 0.1)\n", "ax.set_xlabel(\"Iteration\")\n", - "ax.set_xlim(0,len(losses_over_time))\n", + "ax.set_xlim(0, len(losses_over_time))\n", "fig.tight_layout()\n", "\n", "buffs = []\n", - "for T in tqdm(range(0,len(losses_over_time),5)):\n", + "for T in tqdm(range(0, len(losses_over_time), 5)):\n", " parameters = parameters_over_time[T]\n", " rendered_image = 
render(*parameters)\n", - " rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(),0,-1)[...,2]\n", + " rendered_image = np.moveaxis(rendered_image.detach().cpu().numpy(), 0, -1)[..., 2]\n", " img2.set_array(b.preprocess_for_viz(rendered_image))\n", " line[0].set_xdata(jnp.arange(T))\n", " line[0].set_ydata(losses_over_time[:T])\n", diff --git a/scripts/experiments/gaussian_splatting/splatting_simple.ipynb b/scripts/experiments/gaussian_splatting/splatting_simple.ipynb index 7a98e08c..144e870b 100644 --- a/scripts/experiments/gaussian_splatting/splatting_simple.ipynb +++ b/scripts/experiments/gaussian_splatting/splatting_simple.ipynb @@ -51,11 +51,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=500.0, fy=500.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=50.0\n", + " height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=50.0\n", ")" ] }, @@ -74,7 +70,9 @@ } ], "source": [ - "mesh =b.utils.load_mesh(os.path.join(b.utils.get_assets_dir(), \"sample_objs/icosahedron.obj\"))\n", + "mesh = b.utils.load_mesh(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/icosahedron.obj\")\n", + ")\n", "print(mesh.vertices.shape)" ] }, @@ -88,38 +86,50 @@ "def process_ball(position, quaternion, scaling, opacity):\n", " S = jnp.eye(3) * scaling\n", " R = b.quaternion_to_rotation_matrix(quaternion)\n", - " cov3d = R @ S @ S.transpose() @ R.transpose()\n", - " J = jnp.array([\n", - " [intrinsics.fx / position[2], 0.0, -(intrinsics.fx * position[0]) / (position[2]**2)], \n", - " [0.0, intrinsics.fx / position[2], -(intrinsics.fy * position[1]) / (position[2]**2)], \n", - " [0.0, 0.0, 0.0],\n", - " ])\n", + " cov3d = R @ S @ S.transpose() @ R.transpose()\n", + " J = jnp.array(\n", + " [\n", + " [\n", + " intrinsics.fx / position[2],\n", + " 0.0,\n", + " -(intrinsics.fx * position[0]) / (position[2] ** 2),\n", + " ],\n", + " [\n", + " 0.0,\n", + " intrinsics.fx / position[2],\n", + " -(intrinsics.fy * position[1]) / (position[2] ** 2),\n", + " ],\n", + " [0.0, 0.0, 0.0],\n", + " ]\n", + " )\n", " T = J\n", - " cov2d = (J @ cov3d @ J.transpose())[:2,:2]\n", + " cov2d = (J @ cov3d @ J.transpose())[:2, :2]\n", " cov2d_inv = jnp.linalg.inv(cov2d)\n", - " cov2ddet = (cov2d[0,0]*cov2d[1,1] - cov2d[0,1]**2)\n", - " ball_xy = position[:2] / position[2] * jnp.array([intrinsics.fx, intrinsics.fy]) + jnp.array([intrinsics.cx, intrinsics.cy]) \n", + " cov2ddet = cov2d[0, 0] * cov2d[1, 1] - cov2d[0, 1] ** 2\n", + " ball_xy = position[:2] / position[2] * jnp.array(\n", + " [intrinsics.fx, intrinsics.fy]\n", + " ) + jnp.array([intrinsics.cx, intrinsics.cy])\n", " return (\n", - " jnp.array([cov2d_inv[0,0], cov2d_inv[0,1], cov2d_inv[1,1]]),\n", + " jnp.array([cov2d_inv[0, 0], cov2d_inv[0, 1], cov2d_inv[1, 1]]),\n", " cov2ddet,\n", " jnp.array([ball_xy[0], ball_xy[1]]),\n", " position[2],\n", " opacity,\n", " )\n", "\n", + "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m)->()',\n", - " excluded=(1,2,3,4,5,),\n", + " signature=\"(m)->()\",\n", + " excluded=(\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " 5,\n", + " ),\n", ")\n", - "def get_value_at_pixel(\n", - " ij,\n", - " all_cov,\n", - " all_cov_det,\n", - " all_pixel,\n", - " all_z,\n", - " all_opacity\n", - "):\n", + "def get_value_at_pixel(ij, all_cov, all_cov_det, all_pixel, all_z, all_opacity):\n", " T = 1.0\n", " far = intrinsics.far\n", " running_value = 0.0\n", @@ -130,28 +140,40 @@ " depth_val = all_z[idx]\n", " opacity = 
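
`get_value_at_pixel` below is wrapped with `jnp.vectorize` and a core signature, so its body is written for a single `(2,)` pixel index and broadcast over the whole image grid, while `excluded` keeps the shared splat arrays out of the broadcasting. A self-contained illustration of that decorator (`weighted_norm` is a made-up stand-in for the real per-pixel function):

```python
# jnp.vectorize with a core signature, as used for get_value_at_pixel: the
# decorated function sees one (m,)-vector at a time and is mapped over the
# leading grid dimensions; `excluded` keeps shared arguments un-broadcast.
import functools
import jax.numpy as jnp

@functools.partial(jnp.vectorize, signature="(m)->()", excluded=(1,))
def weighted_norm(ij, weights):
    return jnp.sqrt(jnp.sum(weights * ij**2))

jj, ii = jnp.meshgrid(jnp.arange(4.0), jnp.arange(3.0))
pixels = jnp.stack([jj, ii], axis=-1)  # (3, 4, 2) grid of pixel indices
print(weighted_norm(pixels, jnp.array([1.0, 2.0])).shape)  # (3, 4)
```
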
all_opacity[idx]\n", " d = jnp.array([ball_x, ball_y]) - ij\n", - " power = -0.5 * (cov_x * d[0]**2 + cov_z * d[1]**2 + 2 * cov_y * d[0] * d[1])\n", + " power = -0.5 * (cov_x * d[0] ** 2 + cov_z * d[1] ** 2 + 2 * cov_y * d[0] * d[1])\n", " alpha = jnp.minimum(0.99, opacity * jnp.exp(power))\n", " test_T = T * (1 - alpha)\n", - " alpha_threshold = 1.0/2055.0\n", - " running_value = jnp.where(jnp.logical_or(jnp.logical_or(power > 0.0, alpha < alpha_threshold), T < 0.0001), running_value, running_value + depth_val * alpha * T)\n", - " T = jnp.where(jnp.logical_or(jnp.logical_or(power > 0.0, alpha < 1.0/255.0), T < 0.0001), T, test_T)\n", + " alpha_threshold = 1.0 / 2055.0\n", + " running_value = jnp.where(\n", + " jnp.logical_or(\n", + " jnp.logical_or(power > 0.0, alpha < alpha_threshold), T < 0.0001\n", + " ),\n", + " running_value,\n", + " running_value + depth_val * alpha * T,\n", + " )\n", + " T = jnp.where(\n", + " jnp.logical_or(\n", + " jnp.logical_or(power > 0.0, alpha < 1.0 / 255.0), T < 0.0001\n", + " ),\n", + " T,\n", + " test_T,\n", + " )\n", " running_value += T * far\n", " return running_value\n", "\n", + "\n", "def render(positions, quaternions, scalings, opacitys):\n", - " order = jnp.argsort(positions[:,2])\n", + " order = jnp.argsort(positions[:, 2])\n", " processed_ball_data = jax.vmap(process_ball)(\n", - " positions[order],\n", - " quaternions[order],\n", - " scalings[order],\n", - " opacitys[order]\n", + " positions[order], quaternions[order], scalings[order], opacitys[order]\n", " )\n", " jj, ii = jnp.meshgrid(jnp.arange(intrinsics.width), jnp.arange(intrinsics.height))\n", - " pixel_indices = jnp.stack([jj,ii],axis=-1)\n", + " pixel_indices = jnp.stack([jj, ii], axis=-1)\n", "\n", " image = get_value_at_pixel(pixel_indices, *processed_ball_data)\n", " return image\n", + "\n", + "\n", "render_jit = jax.jit(render)" ] }, @@ -174,16 +196,23 @@ } ], "source": [ - "positions = jnp.array([\n", - " [0.1, -0.4, 12.5],\n", - " [1.1, 0.4, 14.6],\n", - "])\n", + "positions = jnp.array(\n", + " [\n", + " [0.1, -0.4, 12.5],\n", + " [1.1, 0.4, 14.6],\n", + " ]\n", + ")\n", "quaternions = jnp.ones((positions.shape[0], 4))\n", - "scalings = jnp.array([\n", - " [1.0, 1.0, 1.0],\n", - " [1.0, 1.0, 1.0],\n", - " [1.0, 1.0, 1.0],\n", - "]) * 0.4\n", + "scalings = (\n", + " jnp.array(\n", + " [\n", + " [1.0, 1.0, 1.0],\n", + " [1.0, 1.0, 1.0],\n", + " [1.0, 1.0, 1.0],\n", + " ]\n", + " )\n", + " * 0.4\n", + ")\n", "opacitys = jnp.array([1.0, 1.0, 0.5])\n", "image = render_jit(positions, quaternions, scalings, opacitys)\n", "b.get_depth_image(image).convert(\"RGB\")" @@ -196,7 +225,12 @@ "metadata": {}, "outputs": [], "source": [ - "shape_model = jax.vmap(lambda x: jnp.array([jnp.cos(x), jnp.sin(x), 0.0]))(jnp.linspace(0.0, 2*jnp.pi, 100)) * 0.1" + "shape_model = (\n", + " jax.vmap(lambda x: jnp.array([jnp.cos(x), jnp.sin(x), 0.0]))(\n", + " jnp.linspace(0.0, 2 * jnp.pi, 100)\n", + " )\n", + " * 0.1\n", + ")" ] }, { @@ -212,16 +246,36 @@ "except Exception:\n", " pass\n", "\n", - "def render_from_pos_quat(pos,quat):\n", - " image = render(b.apply_transform_jit(shape_model, \n", - " b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), pos) \n", - " ), quaternions, scalings, opacitys)\n", + "\n", + "def render_from_pos_quat(pos, quat):\n", + " image = render(\n", + " b.apply_transform_jit(\n", + " shape_model,\n", + " b.transform_from_rot_and_pos(b.quaternion_to_rotation_matrix(quat), pos),\n", + " ),\n", + " quaternions,\n", + " scalings,\n", + " opacitys,\n", + " )\n", " 
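
The loop body in `get_value_at_pixel` is front-to-back alpha compositing: each splat contributes its depth weighted by its alpha times the remaining transmittance `T`, and whatever transmittance survives falls through to the far plane. The rule in isolation, on made-up values:

```python
# Stripped-down version of the compositing in get_value_at_pixel: walk the
# depth-sorted splats front to back, accumulate depth weighted by
# alpha * transmittance, and send the leftover transmittance to the far plane.
import jax.numpy as jnp

far = 50.0
depths = jnp.array([0.5, 0.7, 0.9])  # sorted near-to-far
alphas = jnp.array([0.3, 0.5, 0.8])  # per-splat opacity at this pixel

def composite(depths, alphas, far):
    T = 1.0      # transmittance so far
    value = 0.0
    for z, a in zip(depths, alphas):
        value += z * a * T
        T *= 1.0 - a
    return value + T * far  # unoccluded remainder hits the background

print(composite(depths, alphas, far))
```
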
return image\n", + "\n", + "\n", "render_from_pos_quat_jit = jax.jit(render_from_pos_quat)\n", "\n", - "def loss(pos,quat, gt_image):\n", - " return jnp.mean((render_from_pos_quat(pos,quat) - gt_image)**2)\n", - "value_and_grad_loss = jax.jit(jax.value_and_grad(loss, argnums=(0,1,)))" + "\n", + "def loss(pos, quat, gt_image):\n", + " return jnp.mean((render_from_pos_quat(pos, quat) - gt_image) ** 2)\n", + "\n", + "\n", + "value_and_grad_loss = jax.jit(\n", + " jax.value_and_grad(\n", + " loss,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " ),\n", + " )\n", + ")" ] }, { @@ -264,12 +318,12 @@ "# random_positions = jax.random.uniform(jax.random.PRNGKey(10),(50,3), minval=-1.0, maxval=1.0) * 0.1\n", "N = shape_model.shape[0]\n", "quaternions = jnp.ones((N, 4))\n", - "scalings = jnp.ones((N, 3))*0.01\n", + "scalings = jnp.ones((N, 3)) * 0.01\n", "opacitys = jnp.ones(N)\n", "\n", "pose = b.transform_from_pos(jnp.array([0.0, 0.0, 2.0]))\n", - "pos,quat = pose[:3,3], b.rotation_matrix_to_quaternion(pose[:3,:3])\n", - "gt_image = render_from_pos_quat_jit(pos,quat)\n", + "pos, quat = pose[:3, 3], b.rotation_matrix_to_quaternion(pose[:3, :3])\n", + "gt_image = render_from_pos_quat_jit(pos, quat)\n", "b.get_depth_image(gt_image)" ] }, @@ -292,14 +346,9 @@ } ], "source": [ - "random_pose = b.distributions.gaussian_vmf_jit(\n", - " jax.random.PRNGKey(10),\n", - " pose,\n", - " 0.04,\n", - " 10.0\n", - ")\n", - "pos,quat = random_pose[:3,3], b.rotation_matrix_to_quaternion(random_pose[:3,:3])\n", - "image = render_from_pos_quat_jit(pos,quat)\n", + "random_pose = b.distributions.gaussian_vmf_jit(jax.random.PRNGKey(10), pose, 0.04, 10.0)\n", + "pos, quat = random_pose[:3, 3], b.rotation_matrix_to_quaternion(random_pose[:3, :3])\n", + "image = render_from_pos_quat_jit(pos, quat)\n", "b.get_depth_image(image)" ] }, @@ -333,14 +382,14 @@ "learning_rate_quat = 0.0001\n", "pbar = tqdm(range(10000))\n", "for _ in pbar:\n", - " value, grads = value_and_grad_loss(pos,quat,gt_image)\n", - " pos = pos - grads[0]*learning_rate_pos\n", - " quat = quat - grads[1]*learning_rate_quat\n", - " parameters_over_time.append((pos,quat))\n", + " value, grads = value_and_grad_loss(pos, quat, gt_image)\n", + " pos = pos - grads[0] * learning_rate_pos\n", + " quat = quat - grads[1] * learning_rate_quat\n", + " parameters_over_time.append((pos, quat))\n", " losses_over_time.append(value)\n", " pbar.set_description(\"Loss: {}\".format(value))\n", "print(value)\n", - "image = render_from_pos_quat_jit(pos,quat)\n", + "image = render_from_pos_quat_jit(pos, quat)\n", "b.get_depth_image(image).convert(\"RGB\")" ] }, @@ -363,21 +412,23 @@ ], "source": [ "T = 0\n", - "fig = plt.figure(figsize=(6,6))\n", - "ax = fig.add_subplot(2,2,1)\n", + "fig = plt.figure(figsize=(6, 6))\n", + "ax = fig.add_subplot(2, 2, 1)\n", "ax.set_title(\"Target\")\n", - "img1 = ax.imshow(b.preprocess_for_viz(gt_image),cmap=b.cmap)\n", - "ax = fig.add_subplot(2,2,2)\n", + "img1 = ax.imshow(b.preprocess_for_viz(gt_image), cmap=b.cmap)\n", + "ax = fig.add_subplot(2, 2, 2)\n", "parameters = parameters_over_time[T]\n", - "img2 = ax.imshow(b.preprocess_for_viz(render_from_pos_quat_jit(*parameters)),cmap=b.cmap)\n", + "img2 = ax.imshow(\n", + " b.preprocess_for_viz(render_from_pos_quat_jit(*parameters)), cmap=b.cmap\n", + ")\n", "title = ax.set_title(f\"Reconstruction\")\n", - "ax = fig.add_subplot(2,1,2)\n", + "ax = fig.add_subplot(2, 1, 2)\n", "line = ax.plot(jnp.arange(T), losses_over_time[:T])\n", "ax.set_yscale(\"log\")\n", "ax.set_title(\"Pixelwise MSE Loss\")\n", 
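
The cells above fit a pose by differentiating the renderer itself: `jax.value_and_grad` over an image MSE, followed by plain gradient steps on position and quaternion. A minimal, self-contained version of that pattern with a toy differentiable "renderer" (one soft blob at the pinhole projection of a 3D point) in place of `render_from_pos_quat`; the intrinsics, learning rate, and blob width are all made up, and only position is optimized:

```python
# Toy differentiable-rendering pose fit: render, MSE loss, value_and_grad,
# manual gradient descent. Not Bayes3D API; everything here is illustrative.
import jax
import jax.numpy as jnp

H = W = 64
fx = fy = 60.0
cx = cy = 32.0

def toy_render(pos):
    u = fx * pos[0] / pos[2] + cx
    v = fy * pos[1] / pos[2] + cy
    jj, ii = jnp.meshgrid(jnp.arange(W), jnp.arange(H))
    d2 = (jj - u) ** 2 + (ii - v) ** 2
    return pos[2] * jnp.exp(-d2 / (2.0 * 4.0**2))  # depth-tinted blob

def loss(pos, gt_image):
    return jnp.mean((toy_render(pos) - gt_image) ** 2)

value_and_grad_loss = jax.jit(jax.value_and_grad(loss))

gt_image = toy_render(jnp.array([0.0, 0.0, 2.0]))
pos = jnp.array([0.2, -0.1, 2.5])  # perturbed initialization
for _ in range(300):
    value, grad = value_and_grad_loss(pos, gt_image)
    pos = pos - 0.05 * grad
print(pos, value)  # pos should drift back toward [0, 0, 2]
```
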
"ax.set_ylim(0.01, 1000.0)\n", "ax.set_xlabel(\"Iteration\")\n", - "ax.set_xlim(0,len(losses_over_time))\n", + "ax.set_xlim(0, len(losses_over_time))\n", "fig.tight_layout()" ] }, @@ -397,7 +448,7 @@ ], "source": [ "buffs = []\n", - "for T in tqdm(range(0,len(losses_over_time),45)):\n", + "for T in tqdm(range(0, len(losses_over_time), 45)):\n", " parameters = parameters_over_time[T]\n", " img2.set_array(b.preprocess_for_viz(render_from_pos_quat_jit(*parameters)))\n", " line[0].set_xdata(jnp.arange(T))\n", @@ -447,10 +498,13 @@ "metadata": {}, "outputs": [], "source": [ - "b.make_gif_from_pil_images([\n", - " b.get_depth_image(render_from_pos_quat_jit(*parameters)).convert(\"RGB\")\n", - " for parameters in parameters_over_time[::2]\n", - "],\"optimization.gif\")" + "b.make_gif_from_pil_images(\n", + " [\n", + " b.get_depth_image(render_from_pos_quat_jit(*parameters)).convert(\"RGB\")\n", + " for parameters in parameters_over_time[::2]\n", + " ],\n", + " \"optimization.gif\",\n", + ")" ] }, { @@ -610,7 +664,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "image = get_value_at_pixel(pixel_indices, *processed_ball_data)\n", "b.get_depth_image(image)" ] @@ -655,32 +708,41 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "@functools.partial(\n", " jnp.vectorize,\n", - " signature='(m)->()',\n", + " signature=\"(m)->()\",\n", " excluded=(1,),\n", ")\n", - "def get_value_at_pixel(\n", - " ij,\n", - " sorted_processed_ball_data\n", - "):\n", + "def get_value_at_pixel(ij, sorted_processed_ball_data):\n", " T = 1.0\n", " far = intrinsics.far\n", " running_value = 0.0\n", " for ball_data in sorted_processed_ball_data:\n", " cov_x, cov_y, cov_z, cov_det, ball_x, ball_y, depth_val, opacity = ball_data\n", " d = jnp.array([ball_x, ball_y]) - ij\n", - " \n", - " power = -0.5 * (cov_x * d[0]**2 + cov_z * d[1]**2 + 2 * cov_y * d[0] * d[1])\n", + "\n", + " power = -0.5 * (cov_x * d[0] ** 2 + cov_z * d[1] ** 2 + 2 * cov_y * d[0] * d[1])\n", " alpha = jnp.minimum(0.99, opacity * jnp.exp(power))\n", " test_T = T * (1 - alpha)\n", "\n", - " running_value = jnp.where(jnp.logical_or(jnp.logical_or(power > 0.0, alpha < 1.0/255.0), T < 0.0001), running_value, running_value + depth_val * alpha * T)\n", - " T = jnp.where(jnp.logical_or(jnp.logical_or(power > 0.0, alpha < 1.0/255.0), T < 0.0001), T, test_T)\n", + " running_value = jnp.where(\n", + " jnp.logical_or(\n", + " jnp.logical_or(power > 0.0, alpha < 1.0 / 255.0), T < 0.0001\n", + " ),\n", + " running_value,\n", + " running_value + depth_val * alpha * T,\n", + " )\n", + " T = jnp.where(\n", + " jnp.logical_or(\n", + " jnp.logical_or(power > 0.0, alpha < 1.0 / 255.0), T < 0.0001\n", + " ),\n", + " T,\n", + " test_T,\n", + " )\n", " running_value += T * far\n", " return running_value\n", "\n", + "\n", "def process_ball(ball_parameters):\n", " position = ball_parameters[:3]\n", " quat = ball_parameters[3:7]\n", @@ -688,37 +750,57 @@ " opacity = ball_parameters[10]\n", " S = jnp.eye(3) * scaling\n", " R = b.quaternion_to_rotation_matrix(quat)\n", - " cov3d = R @ S @ S.transpose() @ R.transpose()\n", - " \n", - " J = jnp.array([\n", - " [intrinsics.fx / position[2], 0.0, -(intrinsics.fx * position[0]) / (position[2]**2)], \n", - " [0.0, intrinsics.fx / position[2], -(intrinsics.fy * position[1]) / (position[2]**2)], \n", - " [0.0, 0.0, 0.0],\n", - " ])\n", + " cov3d = R @ S @ S.transpose() @ R.transpose()\n", + "\n", + " J = jnp.array(\n", + " [\n", + " [\n", + " intrinsics.fx / position[2],\n", + " 0.0,\n", + " -(intrinsics.fx * position[0]) / 
(position[2] ** 2),\n", + " ],\n", + " [\n", + " 0.0,\n", + " intrinsics.fx / position[2],\n", + " -(intrinsics.fy * position[1]) / (position[2] ** 2),\n", + " ],\n", + " [0.0, 0.0, 0.0],\n", + " ]\n", + " )\n", " T = J\n", - " cov2d = (J @ cov3d @ J.transpose())[:2,:2]\n", + " cov2d = (J @ cov3d @ J.transpose())[:2, :2]\n", " cov2d_inv = jnp.linalg.inv(cov2d)\n", - " cov2ddet = (cov2d[0,0]*cov2d[1,1] - cov2d[0,1]**2)\n", - " ball_xy = ball_parameters[:2] / ball_parameters[2] * jnp.array([intrinsics.fx, intrinsics.fy]) + jnp.array([intrinsics.cx, intrinsics.cy]) \n", - " return jnp.array([\n", - " cov2d_inv[0,0], cov2d_inv[0,1], cov2d_inv[1,1],\n", - " cov2ddet,\n", - " ball_xy[0], ball_xy[1],\n", - " position[2],\n", - " opacity,\n", - " ])\n", + " cov2ddet = cov2d[0, 0] * cov2d[1, 1] - cov2d[0, 1] ** 2\n", + " ball_xy = ball_parameters[:2] / ball_parameters[2] * jnp.array(\n", + " [intrinsics.fx, intrinsics.fy]\n", + " ) + jnp.array([intrinsics.cx, intrinsics.cy])\n", + " return jnp.array(\n", + " [\n", + " cov2d_inv[0, 0],\n", + " cov2d_inv[0, 1],\n", + " cov2d_inv[1, 1],\n", + " cov2ddet,\n", + " ball_xy[0],\n", + " ball_xy[1],\n", + " position[2],\n", + " opacity,\n", + " ]\n", + " )\n", "\n", "\n", "def render_ball(ball_parameters):\n", " processed_ball_data = jax.vmap(process_ball)(ball_parameters)\n", " jj, ii = jnp.meshgrid(jnp.arange(intrinsics.width), jnp.arange(intrinsics.height))\n", - " pixel_indices = jnp.stack([jj,ii],axis=-1)\n", + " pixel_indices = jnp.stack([jj, ii], axis=-1)\n", " # sorted_processed_ball_data = jax.lax.sort_key_val(processed_ball_data[:,6], processed_ball_data)\n", " return get_value_at_pixel(pixel_indices, processed_ball_data)\n", "\n", + "\n", "def loss(ball_parameters, gt_depth):\n", " alphas = render_ball(ball_parameters)\n", - " return ((alphas - gt_depth)**2).mean()\n", + " return ((alphas - gt_depth) ** 2).mean()\n", + "\n", + "\n", "value_and_grad_loss = jax.jit(jax.value_and_grad(loss))\n", "\n", "\n", @@ -763,7 +845,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(gt_depth,min=0.0,max=intrinsics.far)" + "b.get_depth_image(gt_depth, min=0.0, max=intrinsics.far)" ] }, { @@ -773,14 +855,16 @@ "metadata": {}, "outputs": [], "source": [ - "ball_parameters = jnp.array([\n", - " [0.0, 0.0, 0.2, 1.0, 0.0, 0.0, 0.0, 0.012, 0.002, 0.002, 0.2],\n", - " [0.0, 0.1, 0.4, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 0.3],\n", - " [0.1, 0.1, 0.5, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", - " [0.0, 0.1, 0.6, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", - " [0.1, 0.1, 0.7, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", - " [0.1, 0.1, 0.8, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", - "])\n", + "ball_parameters = jnp.array(\n", + " [\n", + " [0.0, 0.0, 0.2, 1.0, 0.0, 0.0, 0.0, 0.012, 0.002, 0.002, 0.2],\n", + " [0.0, 0.1, 0.4, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 0.3],\n", + " [0.1, 0.1, 0.5, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", + " [0.0, 0.1, 0.6, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", + " [0.1, 0.1, 0.7, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", + " [0.1, 0.1, 0.8, 1.0, 0.0, 0.0, 0.0, 0.002, 0.002, 0.002, 1.0],\n", + " ]\n", + ")\n", "alphas = render_ball(ball_parameters)\n", "b.get_depth_image(alphas)" ] @@ -798,7 +882,7 @@ " pbar.set_description(\"Processing %f\" % loss_val.item())\n", " ball_parameters -= gradient_quat * 0.005\n", "alphas = render_ball(ball_parameters)\n", - "b.get_depth_image(alphas,min=0.0,max=intrinsics.far)" + "b.get_depth_image(alphas, min=0.0, 
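
`process_ball` implements the EWA splatting projection: the world-space covariance `R S Sᵀ Rᵀ` is pushed through the Jacobian of the perspective map to get a 2x2 pixel-space covariance. The same computation stripped to its core, with made-up intrinsics and the zero third row of `J` dropped so the `[:2, :2]` slice becomes implicit (note this sketch uses `fy` for the second diagonal entry of `J`, where the cells above appear to reuse `fx`):

```python
# EWA projection step in isolation: Σ3D = R S Sᵀ Rᵀ, then Σ2D = J Σ3D Jᵀ
# with the perspective Jacobian J. Intrinsics are illustrative.
import jax.numpy as jnp

fx, fy = 500.0, 500.0

def projected_cov2d(position, R, scaling):
    S = jnp.eye(3) * scaling
    cov3d = R @ S @ S.T @ R.T
    x, y, z = position
    J = jnp.array(
        [
            [fx / z, 0.0, -fx * x / z**2],
            [0.0, fy / z, -fy * y / z**2],
        ]
    )
    return J @ cov3d @ J.T  # 2x2 covariance in pixel units

print(projected_cov2d(jnp.array([0.1, -0.2, 5.0]), jnp.eye(3), 0.05))
# ~ (fx * scaling / z)**2 = 25 on the diagonal for this isotropic case
```
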
max=intrinsics.far)" ] }, { @@ -810,9 +894,9 @@ "source": [ "alphas = render_ball(ball_parameters)\n", "b.clear()\n", - "b.show_cloud(\"1\", unproject_depth(alphas, intrinsics).reshape(-1,3))\n", - "b.show_cloud(\"gt\", unproject_depth(gt_depth, intrinsics).reshape(-1,3), color=b.RED)\n", - "b.get_depth_image(alphas,min=0.0,max=intrinsics.far)" + "b.show_cloud(\"1\", unproject_depth(alphas, intrinsics).reshape(-1, 3))\n", + "b.show_cloud(\"gt\", unproject_depth(gt_depth, intrinsics).reshape(-1, 3), color=b.RED)\n", + "b.get_depth_image(alphas, min=0.0, max=intrinsics.far)" ] }, { @@ -822,7 +906,7 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(ball_parameters[:,-1])" + "plt.plot(ball_parameters[:, -1])" ] }, { @@ -832,7 +916,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(render_ball(ball_parameters[2:]),min=0.0,max=intrinsics.far)" + "b.get_depth_image(render_ball(ball_parameters[2:]), min=0.0, max=intrinsics.far)" ] }, { diff --git a/scripts/experiments/gaussian_splatting/viz_splat.ipynb b/scripts/experiments/gaussian_splatting/viz_splat.ipynb index 124c68fa..fcb9f32c 100644 --- a/scripts/experiments/gaussian_splatting/viz_splat.ipynb +++ b/scripts/experiments/gaussian_splatting/viz_splat.ipynb @@ -14,7 +14,10 @@ "outputs": [], "source": [ "import diff_gaussian_rasterization as dgr\n", - "from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\n", + "from diff_gaussian_rasterization import (\n", + " GaussianRasterizationSettings,\n", + " GaussianRasterizer,\n", + ")\n", "import torch\n", "import os\n", "import numpy as np\n", @@ -65,12 +68,13 @@ " An Ellipsoid is treated as a Sphere of unit radius, with an affine\n", " transformation applied to distort it into the ellipsoidal shape\n", " \"\"\"\n", + "\n", " def __init__(self, transform):\n", " super(Ellipsoid2, self).__init__(1.0)\n", " self.transform = np.array(transform)\n", "\n", " def intrinsic_transform(self):\n", - " return self.transform\n" + " return self.transform" ] }, { @@ -89,12 +93,18 @@ "outputs": [], "source": [ "from plyfile import PlyData, PlyElement\n", + "\n", "path = \"/home/nishadgothoskar/gaussian-splatting/output/93d9b62d-d/point_cloud/iteration_30000/point_cloud.ply\"\n", "plydata = PlyData.read(path)\n", "\n", - "xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n", - " np.asarray(plydata.elements[0][\"y\"]),\n", - " np.asarray(plydata.elements[0][\"z\"])), axis=1)\n", + "xyz = np.stack(\n", + " (\n", + " np.asarray(plydata.elements[0][\"x\"]),\n", + " np.asarray(plydata.elements[0][\"y\"]),\n", + " np.asarray(plydata.elements[0][\"z\"]),\n", + " ),\n", + " axis=1,\n", + ")\n", "opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n", "\n", "features_dc = np.zeros((xyz.shape[0], 3, 1))\n", @@ -103,14 +113,16 @@ "features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n", "\n", "\n", - "scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")]\n", - "scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))\n", + "scale_names = [\n", + " p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")\n", + "]\n", + "scale_names = sorted(scale_names, key=lambda x: int(x.split(\"_\")[-1]))\n", "scales = np.zeros((xyz.shape[0], len(scale_names)))\n", "for idx, attr_name in enumerate(scale_names):\n", " scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n", "\n", "rot_names = [p.name for p in plydata.elements[0].properties if 
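
The PLY fields loaded below come out pre-activation if the checkpoint follows the reference gaussian-splatting implementation's conventions (log-space scales, pre-sigmoid opacities, unnormalized rotation quaternions). A hedged sketch of recovering the usable quantities under that assumption, not verified against this particular file:

```python
# Assumed reference-implementation activations for raw 3DGS PLY fields.
import numpy as np

def activate_splat_fields(scales, rots, opacities):
    scales = np.exp(scales)                                     # log -> linear
    rots = rots / np.linalg.norm(rots, axis=-1, keepdims=True)  # unit quats
    opacities = 1.0 / (1.0 + np.exp(-opacities))                # logit -> [0, 1]
    return scales, rots, opacities
```
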
p.name.startswith(\"rot\")]\n", - "rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))\n", + "rot_names = sorted(rot_names, key=lambda x: int(x.split(\"_\")[-1]))\n", "rots = np.zeros((xyz.shape[0], len(rot_names)))\n", "for idx, attr_name in enumerate(rot_names):\n", " rots[:, idx] = np.asarray(plydata.elements[0][attr_name])" @@ -122,7 +134,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"xyz\", xyz/ 5.0)" + "b.show_cloud(\"xyz\", xyz / 5.0)" ] }, { @@ -173,8 +185,8 @@ "z = np.array([0.0, 0.0, 1.0])\n", "rot = np.hstack([x[:, None], y[:, None], z[:, None]])\n", "mat = np.eye(4)\n", - "mat[:3,:3] = rot\n", - "mat[:3,3] = np.array([13.0, 6.0, 5.0])\n", + "mat[:3, :3] = rot\n", + "mat[:3, 3] = np.array([13.0, 6.0, 5.0])\n", "mat\n", "print(mat)\n", "shape = Ellipsoid2(mat)\n", @@ -187,7 +199,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.show_cloud(\"1\", np.zeros((10,3)))" + "b.show_cloud(\"1\", np.zeros((10, 3)))" ] }, { diff --git a/scripts/experiments/gradient_c2f/pose.ipynb b/scripts/experiments/gradient_c2f/pose.ipynb index e564325f..875c930c 100644 --- a/scripts/experiments/gradient_c2f/pose.ipynb +++ b/scripts/experiments/gradient_c2f/pose.ipynb @@ -8,9 +8,10 @@ "outputs": [], "source": [ "import sys\n", - "sys.path.append('/workspace/bayes3d')\n", - "sys.path.append('/workspace/nvdiffrast')\n", - "sys.path.append('/workspace/nvdiffrast/samples/torch') # for `import util`\n", + "\n", + "sys.path.append(\"/workspace/bayes3d\")\n", + "sys.path.append(\"/workspace/nvdiffrast\")\n", + "sys.path.append(\"/workspace/nvdiffrast/samples/torch\") # for `import util`\n", "# sys.path.append('/workspace/nvdiffrast/nvdiffrast/torch') # for 'nvdiffrast.torch'" ] }, @@ -56,7 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.t3d # sanity check import" + "b.t3d # sanity check import" ] }, { @@ -76,22 +77,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 1e-3\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = 256\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 1e-3\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = 256\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -101,8 +102,11 @@ "metadata": {}, "outputs": [], "source": [ - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()\n", - "mvp = torch.tensor(np.matmul(util.projection(x=0.4), util.translate(0, 0, -3.5)).astype(np.float32), device='cuda')" + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()\n", + "mvp = torch.tensor(\n", + " np.matmul(util.projection(x=0.4), util.translate(0, 0, -3.5)).astype(np.float32),\n", + " device=\"cuda\",\n", + ")" ] }, { @@ -112,33 +116,65 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "# Quaternion math.\n", - 
"#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "\n", "# Unit quaternion.\n", "def q_unit():\n", " return np.asarray([1, 0, 0, 0], np.float32)\n", "\n", + "\n", "# Get a random normalized quaternion.\n", "def q_rnd():\n", " u, v, w = np.random.uniform(0.0, 1.0, size=[3])\n", " v *= 2.0 * np.pi\n", " w *= 2.0 * np.pi\n", - " return np.asarray([(1.0-u)**0.5 * np.sin(v), (1.0-u)**0.5 * np.cos(v), u**0.5 * np.sin(w), u**0.5 * np.cos(w)], np.float32)\n", + " return np.asarray(\n", + " [\n", + " (1.0 - u) ** 0.5 * np.sin(v),\n", + " (1.0 - u) ** 0.5 * np.cos(v),\n", + " u**0.5 * np.sin(w),\n", + " u**0.5 * np.cos(w),\n", + " ],\n", + " np.float32,\n", + " )\n", + "\n", "\n", "# Get a random quaternion from the octahedral symmetric group S_4.\n", "_r2 = 0.5**0.5\n", - "_q_S4 = [[ 1.0, 0.0, 0.0, 0.0], [ 0.0, 1.0, 0.0, 0.0], [ 0.0, 0.0, 1.0, 0.0], [ 0.0, 0.0, 0.0, 1.0],\n", - " [-0.5, 0.5, 0.5, 0.5], [-0.5,-0.5,-0.5, 0.5], [ 0.5,-0.5, 0.5, 0.5], [ 0.5, 0.5,-0.5, 0.5],\n", - " [ 0.5, 0.5, 0.5, 0.5], [-0.5, 0.5,-0.5, 0.5], [ 0.5,-0.5,-0.5, 0.5], [-0.5,-0.5, 0.5, 0.5],\n", - " [ _r2,-_r2, 0.0, 0.0], [ _r2, _r2, 0.0, 0.0], [ 0.0, 0.0, _r2, _r2], [ 0.0, 0.0,-_r2, _r2],\n", - " [ 0.0, _r2, _r2, 0.0], [ _r2, 0.0, 0.0,-_r2], [ _r2, 0.0, 0.0, _r2], [ 0.0,-_r2, _r2, 0.0],\n", - " [ _r2, 0.0, _r2, 0.0], [ 0.0, _r2, 0.0, _r2], [ _r2, 0.0,-_r2, 0.0], [ 0.0,-_r2, 0.0, _r2]]\n", + "_q_S4 = [\n", + " [1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 1.0, 0.0, 0.0],\n", + " [0.0, 0.0, 1.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " [-0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, -0.5, -0.5, 0.5],\n", + " [0.5, -0.5, 0.5, 0.5],\n", + " [0.5, 0.5, -0.5, 0.5],\n", + " [0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, 0.5, -0.5, 0.5],\n", + " [0.5, -0.5, -0.5, 0.5],\n", + " [-0.5, -0.5, 0.5, 0.5],\n", + " [_r2, -_r2, 0.0, 0.0],\n", + " [_r2, _r2, 0.0, 0.0],\n", + " [0.0, 0.0, _r2, _r2],\n", + " [0.0, 0.0, -_r2, _r2],\n", + " [0.0, _r2, _r2, 0.0],\n", + " [_r2, 0.0, 0.0, -_r2],\n", + " [_r2, 0.0, 0.0, _r2],\n", + " [0.0, -_r2, _r2, 0.0],\n", + " [_r2, 0.0, _r2, 0.0],\n", + " [0.0, _r2, 0.0, _r2],\n", + " [_r2, 0.0, -_r2, 0.0],\n", + " [0.0, -_r2, 0.0, _r2],\n", + "]\n", + "\n", + "\n", "def q_rnd_S4():\n", " return np.asarray(_q_S4[np.random.randint(24)], np.float32)\n", "\n", + "\n", "# Quaternion slerp.\n", "def q_slerp(p, q, t):\n", " d = np.dot(p, q)\n", @@ -146,28 +182,31 @@ " q = -q\n", " d = -d\n", " if d > 0.999:\n", - " a = p + t * (q-p)\n", + " a = p + t * (q - p)\n", " return a / np.linalg.norm(a)\n", " t0 = np.arccos(d)\n", " tt = t0 * t\n", " st = np.sin(tt)\n", " st0 = np.sin(t0)\n", " s1 = st / st0\n", - " s0 = np.cos(tt) - d*s1\n", - " return s0*p + s1*q\n", + " s0 = np.cos(tt) - d * s1\n", + " return s0 * p + s1 * q\n", + "\n", "\n", "# Quaterion scale (slerp vs. 
identity quaternion).\n", "def q_scale(q, scl):\n", " return q_slerp(q_unit(), q, scl)\n", "\n", + "\n", "# Quaternion product.\n", "def q_mul(p, q):\n", " s1, V1 = p[0], p[1:]\n", " s2, V2 = q[0], q[1:]\n", - " s = s1*s2 - np.dot(V1, V2)\n", - " V = s1*V2 + s2*V1 + np.cross(V1, V2)\n", + " s = s1 * s2 - np.dot(V1, V2)\n", + " V = s1 * V2 + s2 * V1 + np.cross(V1, V2)\n", " return np.asarray([s, V[0], V[1], V[2]], np.float32)\n", "\n", + "\n", "# Angular difference between two quaternions in degrees.\n", "def q_angle_deg(p, q):\n", " p = p.detach().cpu().numpy()\n", @@ -176,24 +215,49 @@ " d = min(d, 1.0)\n", " return np.degrees(2.0 * np.arccos(d))\n", "\n", + "\n", "# Quaternion product\n", "def q_mul_torch(p, q):\n", - " a = p[0]*q[0] - p[1]*q[1] - p[2]*q[2] - p[3]*q[3]\n", - " b = p[0]*q[1] + p[1]*q[0] + p[2]*q[3] - p[3]*q[2]\n", - " c = p[0]*q[2] + p[2]*q[0] + p[3]*q[1] - p[1]*q[3]\n", - " d = p[0]*q[3] + p[3]*q[0] + p[1]*q[2] - p[2]*q[1]\n", + " a = p[0] * q[0] - p[1] * q[1] - p[2] * q[2] - p[3] * q[3]\n", + " b = p[0] * q[1] + p[1] * q[0] + p[2] * q[3] - p[3] * q[2]\n", + " c = p[0] * q[2] + p[2] * q[0] + p[3] * q[1] - p[1] * q[3]\n", + " d = p[0] * q[3] + p[3] * q[0] + p[1] * q[2] - p[2] * q[1]\n", " return torch.stack([a, b, c, d])\n", "\n", + "\n", "# Convert quaternion to 4x4 rotation matrix.\n", "def q_to_mtx(q):\n", - " r0 = torch.stack([1.0-2.0*q[1]**2 - 2.0*q[2]**2, 2.0*q[0]*q[1] - 2.0*q[2]*q[3], 2.0*q[0]*q[2] + 2.0*q[1]*q[3]])\n", - " r1 = torch.stack([2.0*q[0]*q[1] + 2.0*q[2]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[2]**2, 2.0*q[1]*q[2] - 2.0*q[0]*q[3]])\n", - " r2 = torch.stack([2.0*q[0]*q[2] - 2.0*q[1]*q[3], 2.0*q[1]*q[2] + 2.0*q[0]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[1]**2])\n", + " r0 = torch.stack(\n", + " [\n", + " 1.0 - 2.0 * q[1] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[0] * q[1] - 2.0 * q[2] * q[3],\n", + " 2.0 * q[0] * q[2] + 2.0 * q[1] * q[3],\n", + " ]\n", + " )\n", + " r1 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[1] + 2.0 * q[2] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[1] * q[2] - 2.0 * q[0] * q[3],\n", + " ]\n", + " )\n", + " r2 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[2] - 2.0 * q[1] * q[3],\n", + " 2.0 * q[1] * q[2] + 2.0 * q[0] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[1] ** 2,\n", + " ]\n", + " )\n", " rr = torch.transpose(torch.stack([r0, r1, r2]), 1, 0)\n", - " rr = torch.cat([rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1) # Pad right column.\n", - " rr = torch.cat([rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0) # Pad bottom row.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1\n", + " ) # Pad right column.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0\n", + " ) # Pad bottom row.\n", " return rr\n", "\n", + "\n", "# Transform vertex positions to clip space\n", "def transform_pos(mtx, pos):\n", " t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx, np.ndarray) else mtx\n", @@ -201,35 +265,61 @@ " posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)\n", " return torch.matmul(posw, t_mtx.t())[None, ...]\n", "\n", + "\n", "def render(glctx, mtx, pos, pos_idx, col, col_idx, resolution: int):\n", " # Setup TF graph for reference.\n", " depth_ = pos[..., 2:3]\n", - " depth = torch.tensor([[[(z_val/1)] for z_val in depth_.squeeze()]], dtype=torch.float32).cuda()\n", - " pos_clip = transform_pos(mtx, pos)\n", - " rast_out, _ = dr.rasterize(glctx, 
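
`q_mul` above is the Hamilton product in scalar-first (`[w, x, y, z]`) layout. A quick numerical check that it composes rotations, i.e. `R(p ⊗ q) = R(p) @ R(q)`, using scipy (which stores quaternions scalar-last, hence the roll):

```python
# Verify quaternion multiplication against rotation-matrix composition.
import numpy as np
from scipy.spatial.transform import Rotation as R

def q_mul_np(p, q):
    s1, v1 = p[0], p[1:]
    s2, v2 = q[0], q[1:]
    s = s1 * s2 - np.dot(v1, v2)
    v = s1 * v2 + s2 * v1 + np.cross(v1, v2)
    return np.concatenate([[s], v])

rng = np.random.default_rng(0)
p = rng.normal(size=4)
p /= np.linalg.norm(p)
q = rng.normal(size=4)
q /= np.linalg.norm(q)

to_mtx = lambda quat: R.from_quat(np.roll(quat, -1)).as_matrix()  # wxyz -> xyzw
assert np.allclose(to_mtx(q_mul_np(p, q)), to_mtx(p) @ to_mtx(q))
```
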
pos_clip, pos_idx, resolution=[resolution, resolution])\n", - " color , _ = dr.interpolate(depth, rast_out, pos_idx)\n", + " depth = torch.tensor(\n", + " [[[(z_val / 1)] for z_val in depth_.squeeze()]], dtype=torch.float32\n", + " ).cuda()\n", + " pos_clip = transform_pos(mtx, pos)\n", + " rast_out, _ = dr.rasterize(\n", + " glctx, pos_clip, pos_idx, resolution=[resolution, resolution]\n", + " )\n", + " color, _ = dr.interpolate(depth, rast_out, pos_idx)\n", " # color = dr.antialias(color, rast_out, pos_clip, pos_idx)\n", " return color\n", " # return rast_out[:,:,:,2:3]\n", - " \n", "\n", - " \n", + "\n", "################ Added ######################\n", "from scipy.spatial.transform import Rotation as R\n", "\n", + "\n", "# Convert quaternion and position vector to 4x4 rotation matrix.\n", "def q_v_to_mtx(q, v):\n", - " r0 = torch.stack([1.0-2.0*q[1]**2 - 2.0*q[2]**2, 2.0*q[0]*q[1] - 2.0*q[2]*q[3], 2.0*q[0]*q[2] + 2.0*q[1]*q[3]])\n", - " r1 = torch.stack([2.0*q[0]*q[1] + 2.0*q[2]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[2]**2, 2.0*q[1]*q[2] - 2.0*q[0]*q[3]])\n", - " r2 = torch.stack([2.0*q[0]*q[2] - 2.0*q[1]*q[3], 2.0*q[1]*q[2] + 2.0*q[0]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[1]**2])\n", + " r0 = torch.stack(\n", + " [\n", + " 1.0 - 2.0 * q[1] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[0] * q[1] - 2.0 * q[2] * q[3],\n", + " 2.0 * q[0] * q[2] + 2.0 * q[1] * q[3],\n", + " ]\n", + " )\n", + " r1 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[1] + 2.0 * q[2] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[1] * q[2] - 2.0 * q[0] * q[3],\n", + " ]\n", + " )\n", + " r2 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[2] - 2.0 * q[1] * q[3],\n", + " 2.0 * q[1] * q[2] + 2.0 * q[0] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[1] ** 2,\n", + " ]\n", + " )\n", " rr = torch.transpose(torch.stack([r0, r1, r2]), 1, 0)\n", - " rr = torch.cat([rr, torch.reshape(v, (3,1))], dim=1) \n", - " rr = torch.cat([rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0) # Pad bottom row.\n", + " rr = torch.cat([rr, torch.reshape(v, (3, 1))], dim=1)\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0\n", + " ) # Pad bottom row.\n", " return rr\n", "\n", + "\n", "# Convert quaternion and position vector to 4x4 rotation matrix.\n", "def q_v_to_mtx_batch(qs, vs):\n", - " return torch.stack([q_v_to_mtx(q, v) for q,v in zip(qs, vs)])\n", + " return torch.stack([q_v_to_mtx(q, v) for q, v in zip(qs, vs)])\n", "\n", "\n", "# Get a random position near the origin.\n", @@ -237,11 +327,16 @@ " x, y, z = np.random.uniform(-0.005, 0.005, size=[3])\n", " return np.asarray([x, y, z], np.float32)\n", "\n", + "\n", "# Multiple renders\n", "def render_multiple(glctx, poses, vtx_pos, pos_idx, vtx_col, col_idx, resolution):\n", - " ret = torch.cat([render(glctx, pose, vtx_pos, \n", - " pos_idx, vtx_col, \n", - " col_idx, resolution) for pose in poses], axis=0)\n", + " ret = torch.cat(\n", + " [\n", + " render(glctx, pose, vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", + " for pose in poses\n", + " ],\n", + " axis=0,\n", + " )\n", " return ret" ] }, @@ -253,13 +348,14 @@ "outputs": [], "source": [ "datadir = \"/workspace/nvdiffrast/samples/data/\"\n", - "with np.load(f'{datadir}/cube_p.npz') as f:\n", + "with np.load(f\"{datadir}/cube_p.npz\") as f:\n", " pos_idx, pos, col_idx, col = f.values()\n", "print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "\n", "# Some input geometry contains vertex positions in (N, 4) (with 
v[:,3]==1). Drop\n", "# the last column in that case.\n", - "if pos.shape[1] == 4: pos = pos[:, 0:3]\n", + "if pos.shape[1] == 4:\n", + " pos = pos[:, 0:3]\n", "\n", "# Create position/triangle index tensors\n", "pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()\n", @@ -297,7 +393,7 @@ "metadata": {}, "outputs": [], "source": [ - "UNIT_VECTOR = np.array([0,0,1]) # unit for viz" + "UNIT_VECTOR = np.array([0, 0, 1]) # unit for viz" ] }, { @@ -310,49 +406,66 @@ "def fig2img(fig):\n", " \"\"\"Convert a Matplotlib figure to a PIL Image and return it\"\"\"\n", " import io\n", + "\n", " buf = io.BytesIO()\n", " fig.savefig(buf)\n", " buf.seek(0)\n", " img = Image.open(buf)\n", " return img\n", "\n", + "\n", "def plot_polar_angles_on_frame(thetas, phis, curr_ax):\n", " ax = curr_ax\n", " scaling = 0.96\n", " for theta in thetas:\n", " for phi in phis:\n", - " x = np.cos(phi)*np.cos(theta)\n", - " y = np.cos(phi)*np.sin(theta)\n", + " x = np.cos(phi) * np.cos(theta)\n", + " y = np.cos(phi) * np.sin(theta)\n", " z = np.sin(phi)\n", " ax.scatter(x * scaling, y, z, s=5**2, color=\"red\", alpha=1)\n", - " u, v, w = 1,0,0\n", + " u, v, w = 1, 0, 0\n", + "\n", + "\n", "# ax.quiver(x, y, z, u, v, w, length=0.1, normalize=True, alpha=0.8)\n", "# _, ax = generate_sphere_plot()\n", "# phis = np.arange(0, np.pi, np.pi/10)\n", "# thetas = [0]\n", "# plot_polar_angles_on_frame(thetas, phis, ax)\n", "\n", + "\n", "def plot_cartesian_point_on_frame(point, curr_ax, color=\"red\", alpha=1):\n", " ax = curr_ax\n", " x, y, z = point\n", " ax.scatter(x, y, z, s=5**2, color=color, alpha=alpha)\n", - " \n", - "def plot_rot_and_pos(rot_pt, pos_pt, ax_r, ax_p, color=\"red\", alpha=1, label=None, rot_title=None, pos_title=None):\n", + "\n", + "\n", + "def plot_rot_and_pos(\n", + " rot_pt,\n", + " pos_pt,\n", + " ax_r,\n", + " ax_p,\n", + " color=\"red\",\n", + " alpha=1,\n", + " label=None,\n", + " rot_title=None,\n", + " pos_title=None,\n", + "):\n", " \"\"\"Given points on the spherical coord and the cartesian coord,\n", " Plot on the corresponding rotation and position axes\"\"\"\n", " rx, ry, rz = rot_pt[..., 0], rot_pt[..., 1], rot_pt[..., 2]\n", " px, py, pz = pos_pt[..., 0], pos_pt[..., 1], pos_pt[..., 2]\n", " ax_r.scatter(rx, ry, rz, s=5**2, color=color, alpha=alpha, label=label)\n", " ax_p.scatter(px, py, pz, s=5**2, color=color, alpha=alpha, label=label)\n", - " \n", + "\n", " if label is not None:\n", " ax_r.legend()\n", - "# ax_p.legend(loc=\"upper left\")\n", + " # ax_p.legend(loc=\"upper left\")\n", " if rot_title is not None:\n", " ax_r.set_title(rot_title)\n", " if pos_title is not None:\n", " ax_p.set_title(pos_title)\n", "\n", + "\n", "# _, ax = generate_sphere_plot(show_unit=True)" ] }, @@ -366,25 +479,25 @@ "def generate_sphere_plot(show_unit=False, fig_ax=None):\n", " if fig_ax is None:\n", " fig = plt.figure()\n", - " ax = fig.add_subplot(projection='3d')\n", + " ax = fig.add_subplot(projection=\"3d\")\n", " else:\n", " fig, ax = fig_ax\n", - " \n", + "\n", " ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", " ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", " ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", " # make the grid lines transparent\n", - " ax.xaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - " ax.yaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - " ax.zaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - " u, v = np.mgrid[0:2*np.pi:21j, 0:np.pi:11j]\n", - " x = np.cos(u)*np.sin(v)\n", - " y = np.sin(u)*np.sin(v)\n", + " ax.xaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + 
" ax.yaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + " ax.zaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + " u, v = np.mgrid[0 : 2 * np.pi : 21j, 0 : np.pi : 11j]\n", + " x = np.cos(u) * np.sin(v)\n", + " y = np.sin(u) * np.sin(v)\n", " z = np.cos(v)\n", " ax.set_axis_off()\n", - " ax.axes.set_xlim3d(-1.05, 1.05) \n", - " ax.axes.set_ylim3d(-1.05, 1.05) \n", - " ax.axes.set_zlim3d(-1.05, 1.05) \n", + " ax.axes.set_xlim3d(-1.05, 1.05)\n", + " ax.axes.set_ylim3d(-1.05, 1.05)\n", + " ax.axes.set_zlim3d(-1.05, 1.05)\n", " ax.set_aspect(\"equal\")\n", " ax.plot_wireframe(x, y, z, color=(0.0, 0.0, 0.0, 0.3), linewidths=0.5)\n", "\n", @@ -393,42 +506,50 @@ " ax.axes.set_zlabel(\"z\")\n", "\n", " if show_unit:\n", - " quat_unit = q_to_mtx(torch.tensor([1,0,0,0], device=\"cuda\", dtype=torch.float64)).cpu()[:3, :3] @ torch.tensor(UNIT_VECTOR, dtype=torch.float64) \n", + " quat_unit = q_to_mtx(\n", + " torch.tensor([1, 0, 0, 0], device=\"cuda\", dtype=torch.float64)\n", + " ).cpu()[:3, :3] @ torch.tensor(UNIT_VECTOR, dtype=torch.float64)\n", " ax.scatter(quat_unit[0], quat_unit[1], quat_unit[2], color=\"green\", alpha=1)\n", " return fig, ax\n", + "\n", + "\n", "# _, ax = generate_sphere_plot(show_unit=True)\n", "\n", + "\n", "def generate_cartesian_plot(show_unit=False, fig_ax=None):\n", " if fig_ax is None:\n", " fig = plt.figure()\n", - " ax = fig.add_subplot(projection='3d')\n", + " ax = fig.add_subplot(projection=\"3d\")\n", " else:\n", " fig, ax = fig_ax\n", - "# ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", - "# ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", - "# ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", - " \n", - " ax.axes.set_xlim3d(-1.5, 1.5) \n", - " ax.axes.set_ylim3d(-1.5, 1.5) \n", - " ax.axes.set_zlim3d(-1.5, 1.5) \n", - " \n", + " # ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", + " # ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", + " # ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", + "\n", + " ax.axes.set_xlim3d(-1.5, 1.5)\n", + " ax.axes.set_ylim3d(-1.5, 1.5)\n", + " ax.axes.set_zlim3d(-1.5, 1.5)\n", + "\n", " ax.axes.set_xlabel(\"x\")\n", " ax.axes.set_ylabel(\"y\")\n", " ax.axes.set_zlabel(\"z\")\n", - " \n", + "\n", " if show_unit:\n", " ax.scatter(0.0, 0.0, 0.0, s=5**2, color=\"green\", alpha=1)\n", " return fig, ax\n", + "\n", + "\n", "# _, ax = generate_cartesian_plot(True)\n", "\n", + "\n", "def generate_rotation_translation_plot(show_unit=False):\n", " # set up a figure twice as wide as it is tall\n", " fig = plt.figure(figsize=plt.figaspect(0.5))\n", "\n", " # set up the axes for the first plot\n", - " ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n", - " ax2 = fig.add_subplot(1, 2, 2, projection='3d')\n", - " \n", + " ax1 = fig.add_subplot(1, 2, 1, projection=\"3d\")\n", + " ax2 = fig.add_subplot(1, 2, 2, projection=\"3d\")\n", + "\n", " # Generate the subplots\n", " _, _ = generate_sphere_plot(show_unit, (fig, ax1))\n", " _, _ = generate_cartesian_plot(show_unit, (fig, ax2))\n", @@ -436,11 +557,11 @@ " # Label with title\n", " ax1.set_title(\"Rotation evolution\")\n", " ax2.set_title(\"Translation evolution\")\n", - " \n", + "\n", " return fig, (ax1, ax2)\n", "\n", "\n", - "_, (ax_r, ax_t) = generate_rotation_translation_plot(False)\n" + "_, (ax_r, ax_t) = generate_rotation_translation_plot(False)" ] }, { @@ -459,7 +580,7 @@ " torch.backends.cudnn.benchmark = False\n", " # Set a fixed value for the hash seed\n", " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", - " print(f\"Random seed set as {seed}\")\n" + " print(f\"Random seed 
set as {seed}\")" ] }, { @@ -496,14 +617,22 @@ "outputs": [], "source": [ "# GT) Rotation, Position poses\n", - "pose_rot_target = torch.tensor(q_rnd(), device='cuda')\n", - "pose_pos_target = torch.tensor(v_rnd(), device='cuda')\n", + "pose_rot_target = torch.tensor(q_rnd(), device=\"cuda\")\n", + "pose_pos_target = torch.tensor(v_rnd(), device=\"cuda\")\n", "pose_target = q_v_to_mtx(pose_rot_target, pose_pos_target)\n", "print(\"TARGET POSE=\", pose_target)\n", "\n", "# Initial GT render\n", - "rast_target = render(glctx, torch.matmul(mvp, pose_target), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_target = rast_target[0].detach().cpu().numpy()" + "rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, pose_target),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_target = rast_target[0].detach().cpu().numpy()" ] }, { @@ -522,31 +651,48 @@ "outputs": [], "source": [ "# Opt) Rotation, Position poses\n", - "pose_rot_init = pose_rot_target.cpu().numpy() + 0.1 #+ 1.5\n", - "pose_pos_init = pose_pos_target.cpu().numpy() + np.random.rand(3,)/5\n", - "\n", - "pose_rot_opt = torch.tensor(pose_rot_init / np.sum(pose_rot_init**2)**0.5, dtype=torch.float32, device='cuda', requires_grad=True)\n", - "pose_pos_opt = torch.tensor(pose_pos_init, dtype=torch.float32, device='cuda', requires_grad=True)\n", + "pose_rot_init = pose_rot_target.cpu().numpy() + 0.1 # + 1.5\n", + "pose_pos_init = (\n", + " pose_pos_target.cpu().numpy()\n", + " + np.random.rand(\n", + " 3,\n", + " )\n", + " / 5\n", + ")\n", + "\n", + "pose_rot_opt = torch.tensor(\n", + " pose_rot_init / np.sum(pose_rot_init**2) ** 0.5,\n", + " dtype=torch.float32,\n", + " device=\"cuda\",\n", + " requires_grad=True,\n", + ")\n", + "pose_pos_opt = torch.tensor(\n", + " pose_pos_init, dtype=torch.float32, device=\"cuda\", requires_grad=True\n", + ")\n", "print(pose_rot_opt, pose_pos_opt)\n", "\n", "pose_opt = q_v_to_mtx(pose_rot_opt, pose_pos_opt)\n", "print(pose_opt.shape)\n", "\n", - "# initialize loss \n", - "loss_best = np.inf\n", + "# initialize loss\n", + "loss_best = np.inf\n", "\n", "# Visualize initial state\n", "print(f\"target pose={pose_target},\\ncurrent pose={pose_opt}\")\n", "\n", "# Initial opt render\n", - "rast_opt = render(glctx, torch.matmul(mvp, pose_opt), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_opt = rast_opt[0].detach().cpu().numpy()\n", + "rast_opt = render(\n", + " glctx, torch.matmul(mvp, pose_opt), vtx_pos, pos_idx, vtx_col, col_idx, resolution\n", + ")\n", + "img_opt = rast_opt[0].detach().cpu().numpy()\n", "print(rast_opt.shape, img_opt.shape)\n", "\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])" + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -564,70 +710,110 @@ "metadata": {}, "outputs": [], "source": [ - "def descend_gradient(pose_rot_opt, pose_pos_opt, pose_target, rast_opt, rast_target, it=20, verbose=False, plot=False):\n", + "def descend_gradient(\n", + " pose_rot_opt,\n", + " pose_pos_opt,\n", + " pose_target,\n", + " rast_opt,\n", + " rast_target,\n", + " it=20,\n", + " verbose=False,\n", + " plot=False,\n", + "):\n", " OPTIM_GIF_IMGS = []\n", " loss_best = np.inf\n", - " optimizer = torch.optim.Adam([pose_rot_opt, pose_pos_opt], betas=(0.9, 0.999), lr=lr_base)\n", - " \n", - " if plot: \n", + " 
optimizer = torch.optim.Adam(\n", + " [pose_rot_opt, pose_pos_opt], betas=(0.9, 0.999), lr=lr_base\n", + " )\n", + "\n", + " if plot:\n", " fig, (ax_r, ax_t) = generate_rotation_translation_plot(False)\n", - " plot_rot_and_pos(pose_opt.detach().cpu()[:3, :3] @ UNIT_VECTOR, \n", - " pose_opt.detach().cpu()[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"green\", alpha=1, label=\"Initial\")\n", - " plot_rot_and_pos(pose_opt.detach().cpu()[:3, :3] @ UNIT_VECTOR, \n", - " pose_opt.detach().cpu()[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"blue\", alpha=0.1, label=\"Hypothesis\")\n", - " plot_rot_and_pos(pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR, \n", - " pose_target.detach().cpu()[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"red\", alpha=1, label=\"Target\")\n", - "\n", - " \n", + " plot_rot_and_pos(\n", + " pose_opt.detach().cpu()[:3, :3] @ UNIT_VECTOR,\n", + " pose_opt.detach().cpu()[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"green\",\n", + " alpha=1,\n", + " label=\"Initial\",\n", + " )\n", + " plot_rot_and_pos(\n", + " pose_opt.detach().cpu()[:3, :3] @ UNIT_VECTOR,\n", + " pose_opt.detach().cpu()[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " label=\"Hypothesis\",\n", + " )\n", + " plot_rot_and_pos(\n", + " pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR,\n", + " pose_target.detach().cpu()[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"red\",\n", + " alpha=1,\n", + " label=\"Target\",\n", + " )\n", + "\n", " for i in tqdm(range(it)):\n", " noise = q_unit()\n", " pose_rot_total_opt = q_mul_torch(pose_rot_opt, noise)\n", - " mtx_total_opt = torch.matmul(mvp, q_v_to_mtx(pose_rot_total_opt, pose_pos_opt))\n", - " color_opt = render(glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", + " mtx_total_opt = torch.matmul(mvp, q_v_to_mtx(pose_rot_total_opt, pose_pos_opt))\n", + " color_opt = render(\n", + " glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution\n", + " )\n", "\n", - " diff = (rast_opt - rast_target)**2 # L2 norm.\n", + " diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", " diff = torch.tanh(5.0 * torch.max(diff, dim=-1)[0])\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", "\n", " if (loss_val < loss_best) and (loss_val > 0.0):\n", " loss_best = loss_val\n", - " if (loss_val/loss_best > 1.2):\n", + " if loss_val / loss_best > 1.2:\n", " break\n", - " \n", + "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", - " \n", + "\n", " with torch.no_grad():\n", - " pose_rot_opt /= torch.sum(pose_rot_opt**2)**0.5\n", + " pose_rot_opt /= torch.sum(pose_rot_opt**2) ** 0.5\n", + "\n", + " rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_v_to_mtx(pose_rot_opt, pose_pos_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + " )\n", + " img_opt = rast_opt[0].detach().cpu().numpy()\n", "\n", - " rast_opt = render(glctx, torch.matmul(mvp, q_v_to_mtx(pose_rot_opt, pose_pos_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - " img_opt = rast_opt[0].detach().cpu().numpy()\n", - " \n", - " \n", " if verbose:\n", " print(f\"loss={loss}, rot={pose_rot_total_opt}, pos={pose_pos_opt}\")\n", "\n", - " curr_render_imgs = b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - " ])\n", - " \n", + " curr_render_imgs = b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " 
b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + " )\n", + "\n", " if plot:\n", - " pose_opt_curr_val = q_v_to_mtx(pose_rot_opt, pose_pos_opt).detach().cpu() \n", - " plot_rot_and_pos(pose_opt_curr_val[:3, :3] @ UNIT_VECTOR, \n", - " pose_opt_curr_val[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"blue\", alpha=0.1,\n", - " rot_title=f\"Rotation evolution, iter {i}\", \n", - " pos_title=f\"Translation evolution, iter {i}\") # current\n", + " pose_opt_curr_val = q_v_to_mtx(pose_rot_opt, pose_pos_opt).detach().cpu()\n", + " plot_rot_and_pos(\n", + " pose_opt_curr_val[:3, :3] @ UNIT_VECTOR,\n", + " pose_opt_curr_val[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " rot_title=f\"Rotation evolution, iter {i}\",\n", + " pos_title=f\"Translation evolution, iter {i}\",\n", + " ) # current\n", " curr_PIL = fig2img(fig)\n", " OPTIM_GIF_IMGS.append(b.hstack_images([curr_PIL, curr_render_imgs]))\n", " else:\n", @@ -635,7 +821,17 @@ "\n", " return OPTIM_GIF_IMGS\n", "\n", - "OPTIM_GIF_IMGS = descend_gradient(pose_rot_opt, pose_pos_opt, pose_target, rast_opt, rast_target, it=120, verbose=False, plot=True)\n", + "\n", + "OPTIM_GIF_IMGS = descend_gradient(\n", + " pose_rot_opt,\n", + " pose_pos_opt,\n", + " pose_target,\n", + " rast_opt,\n", + " rast_target,\n", + " it=120,\n", + " verbose=False,\n", + " plot=True,\n", + ")\n", "b.viz.make_gif_from_pil_images(OPTIM_GIF_IMGS, \"render_imgs.gif\")\n", "b.vstack_images([OPTIM_GIF_IMGS[0], OPTIM_GIF_IMGS[-1]])" ] @@ -674,17 +870,27 @@ "outputs": [], "source": [ "# GT) Rotation, Position poses\n", - "pose_rot_target = torch.tensor(q_rnd(), device='cuda')\n", - "pose_pos_target = torch.tensor(v_rnd(), device='cuda')\n", + "pose_rot_target = torch.tensor(q_rnd(), device=\"cuda\")\n", + "pose_pos_target = torch.tensor(v_rnd(), device=\"cuda\")\n", "pose_target = q_v_to_mtx(pose_rot_target, pose_pos_target)\n", "print(\"TARGET POSE=\", pose_target)\n", "\n", "# Initial GT render\n", - "rast_target = render(glctx, torch.matmul(mvp, pose_target), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_target = rast_target[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])" + "rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, pose_target),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_target = rast_target[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -707,18 +913,29 @@ "####\n", "\n", "min_x, min_y, min_z, max_x, max_y, max_z = -0.05, -0.05, -0.05, 0.05, 0.05, 0.05\n", - "num_x, num_y, num_z = 1,1,1\n", - "min_rotation_angle, max_rotation_angle = -jnp.pi/10, jnp.pi/10\n", - "sphere_angle_range = jnp.pi/10\n", + "num_x, num_y, num_z = 1, 1, 1\n", + "min_rotation_angle, max_rotation_angle = -jnp.pi / 10, jnp.pi / 10\n", + "sphere_angle_range = jnp.pi / 10\n", "fibonacci_sphere_points = 2\n", "num_planar_angle_points = 10\n", "\n", "\n", - "pose_delta_enums_jax = b.utils.enumerations.make_pose_grid_enumeration(min_x,min_y,min_z, min_rotation_angle, \n", - " max_x,max_y,max_z, max_rotation_angle,\n", - " num_x,num_y,num_z, \n", - " fibonacci_sphere_points, num_planar_angle_points, \n", - " sphere_angle_range=sphere_angle_range)\n", + "pose_delta_enums_jax = b.utils.enumerations.make_pose_grid_enumeration(\n", + " min_x,\n", + " min_y,\n", + " min_z,\n", + " min_rotation_angle,\n", 
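# The call being reformatted here builds a pose-grid enumeration: the cartesian
# product of a small translation grid with a set of rotations, returned as a
# stack of (N, 4, 4) homogeneous poses. A minimal sketch of the idea, with
# illustrative helper names (this is not bayes3d's implementation):
import jax.numpy as jnp

def _translation_grid(lo, hi, num):
    # num**3 translations spanning the cube [lo, hi]^3
    xs = jnp.linspace(lo, hi, num)
    gx, gy, gz = jnp.meshgrid(xs, xs, xs, indexing="ij")
    return jnp.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=-1)

def _rot_z(angle):
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def _pose_grid(lo, hi, num, angles):
    # one 4x4 pose per (translation, planar angle) pair
    poses = [
        jnp.eye(4).at[:3, :3].set(_rot_z(a)).at[:3, 3].set(t)
        for t in _translation_grid(lo, hi, num)
        for a in angles
    ]
    return jnp.stack(poses)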
+ " max_x,\n", + " max_y,\n", + " max_z,\n", + " max_rotation_angle,\n", + " num_x,\n", + " num_y,\n", + " num_z,\n", + " fibonacci_sphere_points,\n", + " num_planar_angle_points,\n", + " sphere_angle_range=sphere_angle_range,\n", + ")\n", "pose_delta_enums = torch.from_dlpack(jax.dlpack.to_dlpack(pose_delta_enums_jax, True))" ] }, @@ -752,29 +969,40 @@ "source": [ "# Opt) Rotation, Position poses\n", "r = R.from_matrix(pose_enums.detach().cpu()[:, :3, :3])\n", - "poses_rot_init = r.as_quat()\n", - "poses_pos_init = pose_enums[:, :3, -1].cpu().numpy() \n", - "\n", - "poses_rot_opt = torch.tensor(np.divide(poses_rot_init, (np.sum(poses_rot_init**2, axis=1)**0.5)[:, None]), dtype=torch.float32, device='cuda', requires_grad=True)\n", - "poses_pos_opt = torch.tensor(poses_pos_init, dtype=torch.float32, device='cuda', requires_grad=True)\n", + "poses_rot_init = r.as_quat()\n", + "poses_pos_init = pose_enums[:, :3, -1].cpu().numpy()\n", + "\n", + "poses_rot_opt = torch.tensor(\n", + " np.divide(poses_rot_init, (np.sum(poses_rot_init**2, axis=1) ** 0.5)[:, None]),\n", + " dtype=torch.float32,\n", + " device=\"cuda\",\n", + " requires_grad=True,\n", + ")\n", + "poses_pos_opt = torch.tensor(\n", + " poses_pos_init, dtype=torch.float32, device=\"cuda\", requires_grad=True\n", + ")\n", "print(pose_rot_opt, pose_pos_opt)\n", "\n", "poses_opt = q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt)\n", "\n", "# Initial opt render\n", - "rast_opts = render_multiple(glctx, torch.matmul(mvp, poses_opt), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_opts = rast_opts[4].detach().cpu().numpy()\n", + "rast_opts = render_multiple(\n", + " glctx, torch.matmul(mvp, poses_opt), vtx_pos, pos_idx, vtx_col, col_idx, resolution\n", + ")\n", + "img_opts = rast_opts[4].detach().cpu().numpy()\n", "\n", - "# initialize loss \n", - "loss_best = np.inf\n", + "# initialize loss\n", + "loss_best = np.inf\n", "\n", "# Visualize initial state\n", "print(f\"target pose={pose_target},\\ncurrent pose={poses_opt[4]}\")\n", "\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opts[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n" + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opts[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -785,14 +1013,24 @@ "outputs": [], "source": [ "_, (ax_r, ax_t) = generate_rotation_translation_plot(False)\n", - "plot_rot_and_pos(pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR, \n", - " pose_target.detach().cpu()[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"red\", alpha=1, label=\"Target\")\n", - "plot_rot_and_pos(np.einsum('nij,j... -> ni', poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR), \n", - " poses_opt.detach().cpu()[:, :3, -1], \n", - " ax_r, ax_t, \n", - " color=\"blue\", alpha=0.1, label=\"Hypothesis\")" + "plot_rot_and_pos(\n", + " pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR,\n", + " pose_target.detach().cpu()[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"red\",\n", + " alpha=1,\n", + " label=\"Target\",\n", + ")\n", + "plot_rot_and_pos(\n", + " np.einsum(\"nij,j... 
-> ni\", poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR),\n", + " poses_opt.detach().cpu()[:, :3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " label=\"Hypothesis\",\n", + ")" ] }, { @@ -810,40 +1048,73 @@ "metadata": {}, "outputs": [], "source": [ - "def descend_gradient_multi(pose_rot_opt, pose_pos_opt, pose_target, rast_opts, rast_target, it=20, verbose=False, plot=False):\n", + "def descend_gradient_multi(\n", + " pose_rot_opt,\n", + " pose_pos_opt,\n", + " pose_target,\n", + " rast_opts,\n", + " rast_target,\n", + " it=20,\n", + " verbose=False,\n", + " plot=False,\n", + "):\n", " OPTIM_GIF_IMGS = []\n", " loss_best = np.inf\n", - " optimizer = torch.optim.Adam([pose_rot_opt, pose_pos_opt], betas=(0.9, 0.999), lr=2e-7)\n", + " optimizer = torch.optim.Adam(\n", + " [pose_rot_opt, pose_pos_opt], betas=(0.9, 0.999), lr=2e-7\n", + " )\n", " img_target = rast_target[0].detach().cpu().numpy()\n", - " img_target_viz = b.get_depth_image(img_target[:,:,0]* 255.0)\n", - " \n", - " if plot: \n", + " img_target_viz = b.get_depth_image(img_target[:, :, 0] * 255.0)\n", + "\n", + " if plot:\n", " fig, (ax_r, ax_t) = generate_rotation_translation_plot()\n", - " \n", + "\n", " poses_opt = q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt)\n", "\n", - " plot_rot_and_pos(np.einsum('nij,j... -> ni', poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR), \n", - " poses_opt.detach().cpu()[:, :3, -1], \n", - " ax_r, ax_t, \n", - " color=\"green\", alpha=0.1, label=\"Initial\")\n", - " plot_rot_and_pos(np.einsum('nij,j... -> ni', poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR),\n", - " poses_opt.detach().cpu()[:, :3, -1], \n", - " ax_r, ax_t, \n", - " color=\"blue\", alpha=0.1, label=\"Hypothesis\")\n", - " plot_rot_and_pos(pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR, \n", - " pose_target.detach().cpu()[:3, -1], \n", - " ax_r, ax_t, \n", - " color=\"red\", alpha=1, label=\"Target\")\n", - "\n", - " \n", + " plot_rot_and_pos(\n", + " np.einsum(\n", + " \"nij,j... -> ni\", poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR\n", + " ),\n", + " poses_opt.detach().cpu()[:, :3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"green\",\n", + " alpha=0.1,\n", + " label=\"Initial\",\n", + " )\n", + " plot_rot_and_pos(\n", + " np.einsum(\n", + " \"nij,j... 
-> ni\", poses_opt.detach().cpu()[:, :3, :3], UNIT_VECTOR\n", + " ),\n", + " poses_opt.detach().cpu()[:, :3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " label=\"Hypothesis\",\n", + " )\n", + " plot_rot_and_pos(\n", + " pose_target.detach().cpu()[:3, :3] @ UNIT_VECTOR,\n", + " pose_target.detach().cpu()[:3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"red\",\n", + " alpha=1,\n", + " label=\"Target\",\n", + " )\n", + "\n", " # TODO better convergence condition\n", " for i in tqdm(range(it)):\n", - " # noise = q_unit()\n", - " poses_rot_total_opt = poses_rot_opt #q_mul_torch(pose_rot_opt, noise)\n", - " mtx_total_opt = torch.matmul(mvp, q_v_to_mtx_batch(poses_rot_total_opt, poses_pos_opt))\n", - " color_opts = render_multiple(glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "\n", - " diff = (rast_opts - rast_target)**2 # L2 norm.\n", + " # noise = q_unit()\n", + " poses_rot_total_opt = poses_rot_opt # q_mul_torch(pose_rot_opt, noise)\n", + " mtx_total_opt = torch.matmul(\n", + " mvp, q_v_to_mtx_batch(poses_rot_total_opt, poses_pos_opt)\n", + " )\n", + " color_opts = render_multiple(\n", + " glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution\n", + " )\n", + "\n", + " diff = (rast_opts - rast_target) ** 2 # L2 norm.\n", " diff = torch.tanh(5.0 * torch.max(diff, dim=-1)[0])\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", @@ -855,29 +1126,50 @@ " loss.backward()\n", " optimizer.step()\n", "\n", - " # with torch.no_grad():\n", - " # pose_rot_opt /= torch.sum(poses_rot_opt**2, axis=1)**0.5\n", + " # with torch.no_grad():\n", + " # pose_rot_opt /= torch.sum(poses_rot_opt**2, axis=1)**0.5\n", + "\n", + " rast_opts = render_multiple(\n", + " glctx,\n", + " torch.matmul(mvp, q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + " )\n", + " img_opts = rast_opts.detach().cpu().numpy()\n", "\n", - " rast_opts = render_multiple(glctx, torch.matmul(mvp, q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - " img_opts = rast_opts.detach().cpu().numpy()\n", - " \n", " if verbose:\n", " print(f\"loss={loss}, pos[0]={pose_pos_opt[0]}\")\n", "\n", - " curr_render_imgs = b.hvstack_images([b.get_depth_image(img_opts[i][:,:,0]* 255.0) for i in range(len(rast_opts))], \n", - " fibonacci_sphere_points*num_x,\n", - " num_planar_angle_points*num_y*num_z,\n", - " border=10)\n", - " curr_render_imgs = b.vstack_images([img_target_viz, b.scale_image(curr_render_imgs, 0.3)])\n", + " curr_render_imgs = b.hvstack_images(\n", + " [\n", + " b.get_depth_image(img_opts[i][:, :, 0] * 255.0)\n", + " for i in range(len(rast_opts))\n", + " ],\n", + " fibonacci_sphere_points * num_x,\n", + " num_planar_angle_points * num_y * num_z,\n", + " border=10,\n", + " )\n", + " curr_render_imgs = b.vstack_images(\n", + " [img_target_viz, b.scale_image(curr_render_imgs, 0.3)]\n", + " )\n", "\n", " if plot:\n", - " poses_opt_curr_val = q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt).detach().cpu() \n", - " plot_rot_and_pos(np.einsum('nij,j... 
-> ni', poses_opt_curr_val[:, :3, :3], UNIT_VECTOR), \n", - " poses_opt_curr_val[:, :3, -1], \n", - " ax_r, ax_t, \n", - " color=\"blue\", alpha=0.1,\n", - " rot_title=f\"Rotation evolution, iter {i}\", \n", - " pos_title=f\"Translation evolution, iter {i}\") # current\n", + " poses_opt_curr_val = (\n", + " q_v_to_mtx_batch(poses_rot_opt, poses_pos_opt).detach().cpu()\n", + " )\n", + " plot_rot_and_pos(\n", + " np.einsum(\"nij,j... -> ni\", poses_opt_curr_val[:, :3, :3], UNIT_VECTOR),\n", + " poses_opt_curr_val[:, :3, -1],\n", + " ax_r,\n", + " ax_t,\n", + " color=\"blue\",\n", + " alpha=0.1,\n", + " rot_title=f\"Rotation evolution, iter {i}\",\n", + " pos_title=f\"Translation evolution, iter {i}\",\n", + " ) # current\n", " curr_fig = fig2img(fig)\n", " OPTIM_GIF_IMGS.append(b.hstack_images([curr_fig, curr_render_imgs]))\n", " else:\n", @@ -885,9 +1177,17 @@ "\n", " return OPTIM_GIF_IMGS\n", "\n", - "OPTIM_GIF_IMGS = descend_gradient_multi(poses_rot_opt, poses_pos_opt, pose_target, \n", - " rast_opts, rast_target, \n", - " it=30, verbose=False, plot=True)\n", + "\n", + "OPTIM_GIF_IMGS = descend_gradient_multi(\n", + " poses_rot_opt,\n", + " poses_pos_opt,\n", + " pose_target,\n", + " rast_opts,\n", + " rast_target,\n", + " it=30,\n", + " verbose=False,\n", + " plot=True,\n", + ")\n", "b.viz.make_gif_from_pil_images(OPTIM_GIF_IMGS, \"render_imgs_multi.gif\")\n", "b.vstack_images([OPTIM_GIF_IMGS[0], OPTIM_GIF_IMGS[-1]])" ] diff --git a/scripts/experiments/icra/camera_pose_tracking/object_tracking.ipynb b/scripts/experiments/icra/camera_pose_tracking/object_tracking.ipynb index 33d908eb..013f72a5 100644 --- a/scripts/experiments/icra/camera_pose_tracking/object_tracking.ipynb +++ b/scripts/experiments/icra/camera_pose_tracking/object_tracking.ipynb @@ -20,8 +20,7 @@ "import os\n", "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", - "\n" + "# jax.config.update('jax_enable_checks', True)" ] }, { @@ -47,40 +46,48 @@ " jnp.array([0.0, 0.0, 1.0]),\n", ")\n", "\n", - "camera_poses = jnp.array([\n", - " b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) @ camera_pose\n", - " for angle in jnp.linspace(0, 2*jnp.pi, 120)]\n", + "camera_poses = jnp.array(\n", + " [\n", + " b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) @ camera_pose\n", + " for angle in jnp.linspace(0, 2 * jnp.pi, 120)\n", + " ]\n", ")\n", "\n", "poses = jnp.linalg.inv(camera_poses)\n", "\n", - "translation_deltas = b.utils.make_translation_grid_enumeration(-0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 11, 11, 11)\n", - "rotation_deltas = jax.vmap(lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0))(\n", - " jax.random.split(jax.random.PRNGKey(3), 500)\n", + "translation_deltas = b.utils.make_translation_grid_enumeration(\n", + " -0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 11, 11, 11\n", + ")\n", + "rotation_deltas = jax.vmap(\n", + " lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0)\n", + ")(jax.random.split(jax.random.PRNGKey(3), 500))\n", + "\n", + "likelihood = jax.vmap(\n", + " b.threedp3_likelihood_old, in_axes=(None, 0, None, None, None, None, None)\n", ")\n", "\n", - "likelihood = jax.vmap(b.threedp3_likelihood_old, in_axes=(None, 0, None, None, None, None, None))\n", "\n", "def update_pose_estimate(pose_estimate, gt_image):\n", " proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, translation_deltas)\n", - " rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(proposals[:,None, ...], jnp.array([0]))\n", + " 
rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(\n", + " proposals[:, None, ...], jnp.array([0])\n", + " )\n", " weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3)\n", " pose_estimate = proposals[jnp.argmax(weights_new)]\n", "\n", " proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, rotation_deltas)\n", - " rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(proposals[:, None, ...], jnp.array([0]))\n", + " rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(\n", + " proposals[:, None, ...], jnp.array([0])\n", + " )\n", " weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3)\n", " pose_estimate = proposals[jnp.argmax(weights_new)]\n", " return pose_estimate, pose_estimate\n", "\n", - "inference_program = jax.jit(lambda p,x: jax.lax.scan(update_pose_estimate, p,x)[1])\n", + "\n", + "inference_program = jax.jit(lambda p, x: jax.lax.scan(update_pose_estimate, p, x)[1])\n", "\n", "original_intrinsics = b.Intrinsics(\n", - " height=200,\n", - " width=200,\n", - " fx=150.0, fy=150.0,\n", - " cx=100.0, cy=100.0,\n", - " near=0.001, far=6.0\n", + " height=200, width=200, fx=150.0, fy=150.0, cx=100.0, cy=100.0, near=0.001, far=6.0\n", ")" ] }, @@ -129,39 +136,54 @@ " dataa = []\n", " for SCALING_FACTOR_IDX in range(len(scaling_factors)):\n", " print(SCALING_FACTOR_IDX)\n", - " intrinsics = b.scale_camera_parameters(original_intrinsics, scaling_factors[SCALING_FACTOR_IDX])\n", - " \n", + " intrinsics = b.scale_camera_parameters(\n", + " original_intrinsics, scaling_factors[SCALING_FACTOR_IDX]\n", + " )\n", + "\n", " b.setup_renderer(intrinsics)\n", - " model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(object_ids[OBJECT_ID_IDX] + 1).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/100.0)\n", - " \n", - " observed_images = b.RENDERER.render_many(poses[:,None,...], jnp.array([0]))\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + " mesh_path = os.path.join(\n", + " model_dir,\n", + " \"obj_\" + \"{}\".format(object_ids[OBJECT_ID_IDX] + 1).rjust(6, \"0\") + \".ply\",\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 100.0)\n", + "\n", + " observed_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))\n", " print(\"observed_images.shape\", observed_images.shape)\n", - " \n", + "\n", " inferred_poses = inference_program(poses[0], observed_images)\n", - " \n", + "\n", " start = time.time()\n", " pose_estimates_over_time = inference_program(poses[0], observed_images)\n", " end = time.time()\n", - " print (\"Time elapsed:\", end - start)\n", - " fps = poses.shape[0] / (end - start)\n", - " print (\"FPS:\", poses.shape[0] / (end - start))\n", - " dataa.append((scaling_factors[SCALING_FACTOR_IDX], object_ids[OBJECT_ID_IDX], intrinsics.height, fps, pose_estimates_over_time))\n", - " \n", + " print(\"Time elapsed:\", end - start)\n", + " fps = poses.shape[0] / (end - start)\n", + " print(\"FPS:\", poses.shape[0] / (end - start))\n", + " dataa.append(\n", + " (\n", + " scaling_factors[SCALING_FACTOR_IDX],\n", + " object_ids[OBJECT_ID_IDX],\n", + " intrinsics.height,\n", + " fps,\n", + " pose_estimates_over_time,\n", + " )\n", + " )\n", + "\n", " max_depth = 10.0\n", - " rerendered_images = b.RENDERER.render_many(pose_estimates_over_time[:, None, ...], jnp.array([0]))\n", + " rerendered_images = 
b.RENDERER.render_many(\n", + " pose_estimates_over_time[:, None, ...], jnp.array([0])\n", + " )\n", " viz_images = []\n", - " for (r, d) in zip(rerendered_images, observed_images):\n", - " viz_r = b.viz.scale_image(b.viz.get_depth_image(r[:,:,2]), 5.0)\n", - " viz_d = b.viz.scale_image(b.viz.get_depth_image(d[:,:,2]), 5.0)\n", - " overlay = b.viz.overlay_image(viz_r,viz_d)\n", - " viz_images.append(b.viz.multi_panel(\n", - " [\n", - " viz_d, viz_r, overlay\n", - " ],\n", - " [\"Ground Truth\", \"Inferred Reconstruction\", \"Overlay\"],\n", - " ))\n", + " for r, d in zip(rerendered_images, observed_images):\n", + " viz_r = b.viz.scale_image(b.viz.get_depth_image(r[:, :, 2]), 5.0)\n", + " viz_d = b.viz.scale_image(b.viz.get_depth_image(d[:, :, 2]), 5.0)\n", + " overlay = b.viz.overlay_image(viz_r, viz_d)\n", + " viz_images.append(\n", + " b.viz.multi_panel(\n", + " [viz_d, viz_r, overlay],\n", + " [\"Ground Truth\", \"Inferred Reconstruction\", \"Overlay\"],\n", + " )\n", + " )\n", "\n", " b.make_gif_from_pil_images(viz_images, \"demo.gif\")\n", " data.append(dataa)" @@ -185,8 +207,8 @@ "outputs": [], "source": [ "def error_between_poses(pose_1, pose_2):\n", - " translation_error = jnp.linalg.norm(pose_1[:3,3] - pose_2[:3,3])\n", - " error_rotvec = R.from_matrix((pose_1 @ jnp.linalg.inv(pose_2))[:3,:3]).as_rotvec()\n", + " translation_error = jnp.linalg.norm(pose_1[:3, 3] - pose_2[:3, 3])\n", + " error_rotvec = R.from_matrix((pose_1 @ jnp.linalg.inv(pose_2))[:3, :3]).as_rotvec()\n", " rotation_error = jnp.rad2deg(jnp.linalg.norm(error_rotvec))\n", " return translation_error, rotation_error" ] @@ -200,10 +222,25 @@ "source": [ "output_string = \"\"\n", "for OBJECT_ID_IDX in range(len(object_ids)):\n", - " for SCALING_FACTOR_IDX in range(len(scaling_factors)-1,-1,-1):\n", - " scaling_factor, object_id, resolution, fps, poses_inferred = data[OBJECT_ID_IDX][SCALING_FACTOR_IDX]\n", - " errors = jnp.array([error_between_poses(p,t) for (p,t) in zip(poses_inferred, poses)])\n", - " print(object_id, \" & \", resolution, \" & \",f\"{fps:0.3f}\", \" & \", f\"{(float(errors[:,0].mean() * 10.0)):0.3f}\", \" & \", f\"{(float(errors[:,1].mean() * 1.0)):0.3f}\", \"\\\\\\\\\")" + " for SCALING_FACTOR_IDX in range(len(scaling_factors) - 1, -1, -1):\n", + " scaling_factor, object_id, resolution, fps, poses_inferred = data[\n", + " OBJECT_ID_IDX\n", + " ][SCALING_FACTOR_IDX]\n", + " errors = jnp.array(\n", + " [error_between_poses(p, t) for (p, t) in zip(poses_inferred, poses)]\n", + " )\n", + " print(\n", + " object_id,\n", + " \" & \",\n", + " resolution,\n", + " \" & \",\n", + " f\"{fps:0.3f}\",\n", + " \" & \",\n", + " f\"{(float(errors[:,0].mean() * 10.0)):0.3f}\",\n", + " \" & \",\n", + " f\"{(float(errors[:,1].mean() * 1.0)):0.3f}\",\n", + " \"\\\\\\\\\",\n", + " )" ] }, { @@ -229,8 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "observed_images = b.RENDERER.render_many(poses[:,None,...], jnp.array([0]))\n", + "observed_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))\n", "print(\"observed_images.shape\", observed_images.shape)\n", "\n", "inferred_poses = inference_program(poses[0], observed_images)\n", @@ -238,8 +274,8 @@ "start = time.time()\n", "pose_estimates_over_time = inference_program(poses[0], observed_images)\n", "end = time.time()\n", - "print (\"Time elapsed:\", end - start)\n", - "print (\"FPS:\", poses.shape[0] / (end - start))\n" + "print(\"Time elapsed:\", end - start)\n", + "print(\"FPS:\", poses.shape[0] / (end - start))" ] }, { @@ -249,7 +285,7 @@ 
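# For reference, update_pose_estimate above is enumeration-based tracking:
# apply a fixed bank of pose deltas to the current estimate, render and score
# every proposal, keep the argmax, and jax.lax.scan that step over the frames.
# A stripped-down sketch of the same propose/score/argmax pattern, where
# score_fn and deltas are placeholders rather than bayes3d APIs:
import jax
import jax.numpy as jnp

def make_tracker(deltas, score_fn):
    def step(pose_estimate, observation):
        # compose every delta with the current estimate: (A, 4, 4) proposals
        proposals = jnp.einsum("ij,ajk->aik", pose_estimate, deltas)
        scores = jax.vmap(score_fn, in_axes=(0, None))(proposals, observation)
        best = proposals[jnp.argmax(scores)]
        return best, best  # new carry, per-frame output

    # returns the full pose trajectory for a stack of observations
    return jax.jit(
        lambda pose0, observations: jax.lax.scan(step, pose0, observations)[1]
    )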
"metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(observed_images[0][:,:,2])" + "b.get_depth_image(observed_images[0][:, :, 2])" ] }, { diff --git a/scripts/experiments/icra/camera_pose_tracking/pose.ipynb b/scripts/experiments/icra/camera_pose_tracking/pose.ipynb index c5b7f0bc..2ab539cc 100644 --- a/scripts/experiments/icra/camera_pose_tracking/pose.ipynb +++ b/scripts/experiments/icra/camera_pose_tracking/pose.ipynb @@ -18,7 +18,7 @@ "import bayes3d as b\n", "from tqdm import tqdm\n", "\n", - "import nvdiffrast.torch as dr\n" + "import nvdiffrast.torch as dr" ] }, { @@ -28,22 +28,22 @@ "metadata": {}, "outputs": [], "source": [ - "max_iter = 10000\n", - "repeats = 1\n", - "log_interval = 10\n", - "display_interval = None\n", - "display_res = 512\n", - "lr_base = 0.01\n", - "lr_falloff = 1.0\n", - "nr_base = 1.0\n", - "nr_falloff = 1e-4\n", - "grad_phase_start = 0.5\n", - "resolution = 256\n", - "out_dir = None\n", - "log_fn = None\n", - "mp4save_interval = None\n", - "mp4save_fn = None\n", - "use_opengl = False" + "max_iter = 10000\n", + "repeats = 1\n", + "log_interval = 10\n", + "display_interval = None\n", + "display_res = 512\n", + "lr_base = 0.01\n", + "lr_falloff = 1.0\n", + "nr_base = 1.0\n", + "nr_falloff = 1e-4\n", + "grad_phase_start = 0.5\n", + "resolution = 256\n", + "out_dir = None\n", + "log_fn = None\n", + "mp4save_interval = None\n", + "mp4save_fn = None\n", + "use_opengl = False" ] }, { @@ -53,8 +53,11 @@ "metadata": {}, "outputs": [], "source": [ - "glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()\n", - "mvp = torch.tensor(np.matmul(util.projection(x=0.4), util.translate(0, 0, -3.5)).astype(np.float32), device='cuda')\n" + "glctx = dr.RasterizeGLContext() # if use_opengl else dr.RasterizeCudaContext()\n", + "mvp = torch.tensor(\n", + " np.matmul(util.projection(x=0.4), util.translate(0, 0, -3.5)).astype(np.float32),\n", + " device=\"cuda\",\n", + ")" ] }, { @@ -64,33 +67,65 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "# Quaternion math.\n", - "#----------------------------------------------------------------------------\n", + "# ----------------------------------------------------------------------------\n", "\n", "# Unit quaternion.\n", "def q_unit():\n", " return np.asarray([1, 0, 0, 0], np.float32)\n", "\n", + "\n", "# Get a random normalized quaternion.\n", "def q_rnd():\n", " u, v, w = np.random.uniform(0.0, 1.0, size=[3])\n", " v *= 2.0 * np.pi\n", " w *= 2.0 * np.pi\n", - " return np.asarray([(1.0-u)**0.5 * np.sin(v), (1.0-u)**0.5 * np.cos(v), u**0.5 * np.sin(w), u**0.5 * np.cos(w)], np.float32)\n", + " return np.asarray(\n", + " [\n", + " (1.0 - u) ** 0.5 * np.sin(v),\n", + " (1.0 - u) ** 0.5 * np.cos(v),\n", + " u**0.5 * np.sin(w),\n", + " u**0.5 * np.cos(w),\n", + " ],\n", + " np.float32,\n", + " )\n", + "\n", "\n", "# Get a random quaternion from the octahedral symmetric group S_4.\n", "_r2 = 0.5**0.5\n", - "_q_S4 = [[ 1.0, 0.0, 0.0, 0.0], [ 0.0, 1.0, 0.0, 0.0], [ 0.0, 0.0, 1.0, 0.0], [ 0.0, 0.0, 0.0, 1.0],\n", - " [-0.5, 0.5, 0.5, 0.5], [-0.5,-0.5,-0.5, 0.5], [ 0.5,-0.5, 0.5, 0.5], [ 0.5, 0.5,-0.5, 0.5],\n", - " [ 0.5, 0.5, 0.5, 0.5], [-0.5, 0.5,-0.5, 0.5], [ 0.5,-0.5,-0.5, 0.5], [-0.5,-0.5, 0.5, 0.5],\n", - " [ _r2,-_r2, 0.0, 0.0], [ _r2, _r2, 0.0, 0.0], [ 0.0, 0.0, _r2, _r2], [ 0.0, 0.0,-_r2, _r2],\n", - " [ 0.0, _r2, _r2, 0.0], [ _r2, 
0.0, 0.0,-_r2], [ _r2, 0.0, 0.0, _r2], [ 0.0,-_r2, _r2, 0.0],\n", - " [ _r2, 0.0, _r2, 0.0], [ 0.0, _r2, 0.0, _r2], [ _r2, 0.0,-_r2, 0.0], [ 0.0,-_r2, 0.0, _r2]]\n", + "_q_S4 = [\n", + " [1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 1.0, 0.0, 0.0],\n", + " [0.0, 0.0, 1.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " [-0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, -0.5, -0.5, 0.5],\n", + " [0.5, -0.5, 0.5, 0.5],\n", + " [0.5, 0.5, -0.5, 0.5],\n", + " [0.5, 0.5, 0.5, 0.5],\n", + " [-0.5, 0.5, -0.5, 0.5],\n", + " [0.5, -0.5, -0.5, 0.5],\n", + " [-0.5, -0.5, 0.5, 0.5],\n", + " [_r2, -_r2, 0.0, 0.0],\n", + " [_r2, _r2, 0.0, 0.0],\n", + " [0.0, 0.0, _r2, _r2],\n", + " [0.0, 0.0, -_r2, _r2],\n", + " [0.0, _r2, _r2, 0.0],\n", + " [_r2, 0.0, 0.0, -_r2],\n", + " [_r2, 0.0, 0.0, _r2],\n", + " [0.0, -_r2, _r2, 0.0],\n", + " [_r2, 0.0, _r2, 0.0],\n", + " [0.0, _r2, 0.0, _r2],\n", + " [_r2, 0.0, -_r2, 0.0],\n", + " [0.0, -_r2, 0.0, _r2],\n", + "]\n", + "\n", + "\n", "def q_rnd_S4():\n", " return np.asarray(_q_S4[np.random.randint(24)], np.float32)\n", "\n", + "\n", "# Quaternion slerp.\n", "def q_slerp(p, q, t):\n", " d = np.dot(p, q)\n", @@ -98,28 +133,31 @@ " q = -q\n", " d = -d\n", " if d > 0.999:\n", - " a = p + t * (q-p)\n", + " a = p + t * (q - p)\n", " return a / np.linalg.norm(a)\n", " t0 = np.arccos(d)\n", " tt = t0 * t\n", " st = np.sin(tt)\n", " st0 = np.sin(t0)\n", " s1 = st / st0\n", - " s0 = np.cos(tt) - d*s1\n", - " return s0*p + s1*q\n", + " s0 = np.cos(tt) - d * s1\n", + " return s0 * p + s1 * q\n", + "\n", "\n", "# Quaterion scale (slerp vs. identity quaternion).\n", "def q_scale(q, scl):\n", " return q_slerp(q_unit(), q, scl)\n", "\n", + "\n", "# Quaternion product.\n", "def q_mul(p, q):\n", " s1, V1 = p[0], p[1:]\n", " s2, V2 = q[0], q[1:]\n", - " s = s1*s2 - np.dot(V1, V2)\n", - " V = s1*V2 + s2*V1 + np.cross(V1, V2)\n", + " s = s1 * s2 - np.dot(V1, V2)\n", + " V = s1 * V2 + s2 * V1 + np.cross(V1, V2)\n", " return np.asarray([s, V[0], V[1], V[2]], np.float32)\n", "\n", + "\n", "# Angular difference between two quaternions in degrees.\n", "def q_angle_deg(p, q):\n", " p = p.detach().cpu().numpy()\n", @@ -128,24 +166,49 @@ " d = min(d, 1.0)\n", " return np.degrees(2.0 * np.arccos(d))\n", "\n", + "\n", "# Quaternion product\n", "def q_mul_torch(p, q):\n", - " a = p[0]*q[0] - p[1]*q[1] - p[2]*q[2] - p[3]*q[3]\n", - " b = p[0]*q[1] + p[1]*q[0] + p[2]*q[3] - p[3]*q[2]\n", - " c = p[0]*q[2] + p[2]*q[0] + p[3]*q[1] - p[1]*q[3]\n", - " d = p[0]*q[3] + p[3]*q[0] + p[1]*q[2] - p[2]*q[1]\n", + " a = p[0] * q[0] - p[1] * q[1] - p[2] * q[2] - p[3] * q[3]\n", + " b = p[0] * q[1] + p[1] * q[0] + p[2] * q[3] - p[3] * q[2]\n", + " c = p[0] * q[2] + p[2] * q[0] + p[3] * q[1] - p[1] * q[3]\n", + " d = p[0] * q[3] + p[3] * q[0] + p[1] * q[2] - p[2] * q[1]\n", " return torch.stack([a, b, c, d])\n", "\n", + "\n", "# Convert quaternion to 4x4 rotation matrix.\n", "def q_to_mtx(q):\n", - " r0 = torch.stack([1.0-2.0*q[1]**2 - 2.0*q[2]**2, 2.0*q[0]*q[1] - 2.0*q[2]*q[3], 2.0*q[0]*q[2] + 2.0*q[1]*q[3]])\n", - " r1 = torch.stack([2.0*q[0]*q[1] + 2.0*q[2]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[2]**2, 2.0*q[1]*q[2] - 2.0*q[0]*q[3]])\n", - " r2 = torch.stack([2.0*q[0]*q[2] - 2.0*q[1]*q[3], 2.0*q[1]*q[2] + 2.0*q[0]*q[3], 1.0 - 2.0*q[0]**2 - 2.0*q[1]**2])\n", + " r0 = torch.stack(\n", + " [\n", + " 1.0 - 2.0 * q[1] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[0] * q[1] - 2.0 * q[2] * q[3],\n", + " 2.0 * q[0] * q[2] + 2.0 * q[1] * q[3],\n", + " ]\n", + " )\n", + " r1 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[1] + 2.0 * q[2] * q[3],\n", + " 1.0 
- 2.0 * q[0] ** 2 - 2.0 * q[2] ** 2,\n", + " 2.0 * q[1] * q[2] - 2.0 * q[0] * q[3],\n", + " ]\n", + " )\n", + " r2 = torch.stack(\n", + " [\n", + " 2.0 * q[0] * q[2] - 2.0 * q[1] * q[3],\n", + " 2.0 * q[1] * q[2] + 2.0 * q[0] * q[3],\n", + " 1.0 - 2.0 * q[0] ** 2 - 2.0 * q[1] ** 2,\n", + " ]\n", + " )\n", " rr = torch.transpose(torch.stack([r0, r1, r2]), 1, 0)\n", - " rr = torch.cat([rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1) # Pad right column.\n", - " rr = torch.cat([rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0) # Pad bottom row.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0], [0], [0]], dtype=torch.float32).cuda()], dim=1\n", + " ) # Pad right column.\n", + " rr = torch.cat(\n", + " [rr, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32).cuda()], dim=0\n", + " ) # Pad bottom row.\n", " return rr\n", "\n", + "\n", "# Transform vertex positions to clip space\n", "def transform_pos(mtx, pos):\n", " t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx, np.ndarray) else mtx\n", @@ -153,16 +216,21 @@ " posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)\n", " return torch.matmul(posw, t_mtx.t())[None, ...]\n", "\n", + "\n", "def render(glctx, mtx, pos, pos_idx, col, col_idx, resolution: int):\n", " # Setup TF graph for reference.\n", " depth_ = pos[..., 2:3]\n", - " depth = torch.tensor([[[(z_val/1)] for z_val in depth_.squeeze()]], dtype=torch.float32).cuda()\n", - " pos_clip = transform_pos(mtx, pos)\n", - " rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution=[resolution, resolution])\n", - " color , _ = dr.interpolate(depth, rast_out, pos_idx)\n", + " depth = torch.tensor(\n", + " [[[(z_val / 1)] for z_val in depth_.squeeze()]], dtype=torch.float32\n", + " ).cuda()\n", + " pos_clip = transform_pos(mtx, pos)\n", + " rast_out, _ = dr.rasterize(\n", + " glctx, pos_clip, pos_idx, resolution=[resolution, resolution]\n", + " )\n", + " color, _ = dr.interpolate(depth, rast_out, pos_idx)\n", " # color = dr.antialias(color, rast_out, pos_clip, pos_idx)\n", " return color\n", - " # return rast_out[:,:,:,2:3]\n" + " # return rast_out[:,:,:,2:3]" ] }, { @@ -173,13 +241,14 @@ "outputs": [], "source": [ "datadir = \"/home/nishadgothoskar/bayes3d/nvdiffrast/samples/data/\"\n", - "with np.load(f'{datadir}/cube_p.npz') as f:\n", + "with np.load(f\"{datadir}/cube_p.npz\") as f:\n", " pos_idx, pos, col_idx, col = f.values()\n", "print(\"Mesh has %d triangles and %d vertices.\" % (pos_idx.shape[0], pos.shape[0]))\n", "\n", "# Some input geometry contains vertex positions in (N, 4) (with v[:,3]==1). 
Drop\n", "# the last column in that case.\n", - "if pos.shape[1] == 4: pos = pos[:, 0:3]\n", + "if pos.shape[1] == 4:\n", + " pos = pos[:, 0:3]\n", "\n", "# Create position/triangle index tensors\n", "pos_idx = torch.from_numpy(pos_idx.astype(np.int32)).cuda()\n", @@ -218,12 +287,22 @@ "metadata": {}, "outputs": [], "source": [ - "pose_target = torch.tensor(q_rnd(), device='cuda')\n", - "rast_target = render(glctx, torch.matmul(mvp, q_to_mtx(pose_target)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_target = rast_target[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n" + "pose_target = torch.tensor(q_rnd(), device=\"cuda\")\n", + "rast_target = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_target)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_target = rast_target[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -233,16 +312,31 @@ "metadata": {}, "outputs": [], "source": [ - "pose_init = pose_target.cpu().numpy() + 0.3\n", - "pose_opt = torch.tensor(pose_init / np.sum(pose_init**2)**0.5, dtype=torch.float32, device='cuda', requires_grad=True)\n", - "loss_best = np.inf\n", - "\n", - "rast_opt = render(glctx, torch.matmul(mvp, q_to_mtx(pose_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - "img_opt = rast_opt[0].detach().cpu().numpy()\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n" + "pose_init = pose_target.cpu().numpy() + 0.3\n", + "pose_opt = torch.tensor(\n", + " pose_init / np.sum(pose_init**2) ** 0.5,\n", + " dtype=torch.float32,\n", + " device=\"cuda\",\n", + " requires_grad=True,\n", + ")\n", + "loss_best = np.inf\n", + "\n", + "rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + ")\n", + "img_opt = rast_opt[0].detach().cpu().numpy()\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + ")" ] }, { @@ -258,33 +352,45 @@ "for _ in tqdm(range(200)):\n", " noise = q_unit()\n", " pose_total_opt = q_mul_torch(pose_opt, noise)\n", - " mtx_total_opt = torch.matmul(mvp, q_to_mtx(pose_total_opt))\n", - " color_opt = render(glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - " \n", - " diff = (rast_opt - rast_target)**2 # L2 norm.\n", + " mtx_total_opt = torch.matmul(mvp, q_to_mtx(pose_total_opt))\n", + " color_opt = render(\n", + " glctx, mtx_total_opt, vtx_pos, pos_idx, vtx_col, col_idx, resolution\n", + " )\n", + "\n", + " diff = (rast_opt - rast_target) ** 2 # L2 norm.\n", " diff = torch.tanh(5.0 * torch.max(diff, dim=-1)[0])\n", " loss = torch.mean(diff)\n", " loss_val = float(loss)\n", - " \n", + "\n", " if (loss_val < loss_best) and (loss_val > 0.0):\n", " loss_best = loss_val\n", - " \n", + "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(pose_opt.grad)\n", - " \n", + "\n", " with torch.no_grad():\n", - " pose_opt /= torch.sum(pose_opt**2)**0.5\n", - " \n", - " rast_opt = render(glctx, torch.matmul(mvp, q_to_mtx(pose_opt)), vtx_pos, pos_idx, vtx_col, col_idx, resolution)\n", - " img_opt = rast_opt[0].detach().cpu().numpy()\n", + " 
pose_opt /= torch.sum(pose_opt**2) ** 0.5\n", + "\n", + " rast_opt = render(\n", + " glctx,\n", + " torch.matmul(mvp, q_to_mtx(pose_opt)),\n", + " vtx_pos,\n", + " pos_idx,\n", + " vtx_col,\n", + " col_idx,\n", + " resolution,\n", + " )\n", + " img_opt = rast_opt[0].detach().cpu().numpy()\n", " images.append(\n", - "b.hstack_images([\n", - " b.get_depth_image(img_opt[:,:,0]* 255.0) ,\n", - " b.get_depth_image(img_target[:,:,0]* 255.0) ,\n", - "])\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(img_opt[:, :, 0] * 255.0),\n", + " b.get_depth_image(img_target[:, :, 0] * 255.0),\n", + " ]\n", + " )\n", " )" ] }, @@ -295,7 +401,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.vstack_images([images[0],images[-1]])" + "b.vstack_images([images[0], images[-1]])" ] }, { diff --git a/scripts/experiments/icra/fork_knife/fork-knife-cnn.ipynb b/scripts/experiments/icra/fork_knife/fork-knife-cnn.ipynb index 9ff47916..f9611d51 100644 --- a/scripts/experiments/icra/fork_knife/fork-knife-cnn.ipynb +++ b/scripts/experiments/icra/fork_knife/fork-knife-cnn.ipynb @@ -93,23 +93,22 @@ "outputs": [], "source": [ "class CNN(nn.Module):\n", - "\n", " @nn.compact\n", - " def __call__(self, imgs): # XXX todo: make sure the sizes all line up\n", - " x = nn.Conv(64, (10, 10), strides=(3, 3), padding='VALID')(imgs)\n", + " def __call__(self, imgs): # XXX todo: make sure the sizes all line up\n", + " x = nn.Conv(64, (10, 10), strides=(3, 3), padding=\"VALID\")(imgs)\n", " x = nn.activation.relu(x)\n", " x = nn.max_pool(x, (3, 3), strides=(2, 2))\n", - " \n", - " x = nn.Conv(128, (5, 5), strides=(2, 2), padding='VALID')(x)\n", + "\n", + " x = nn.Conv(128, (5, 5), strides=(2, 2), padding=\"VALID\")(x)\n", " x = nn.activation.relu(x)\n", " x = nn.max_pool(x, (2, 2), strides=(2, 2))\n", - " \n", + "\n", " x = nn.Dense(1024)(x.reshape(imgs.shape[0], -1))\n", " x = nn.activation.relu(x)\n", - " \n", + "\n", " x = nn.Dense(1024)(x)\n", " x = nn.activation.relu(x)\n", - " \n", + "\n", " x = nn.Dense(576)(x)\n", " x = nn.activation.relu(x)\n", "\n", @@ -120,7 +119,7 @@ " x = nn.activation.relu(x)\n", "\n", " x = nn.Dense(2)(x)\n", - " return x " + " return x" ] }, { @@ -131,7 +130,11 @@ "outputs": [], "source": [ "cnn = CNN()\n", - "cnn.tabulate(jax.random.PRNGKey(0), jnp.zeros((1, 100, 100, 1)), console_kwargs={'force_jupyter': True})" + "cnn.tabulate(\n", + " jax.random.PRNGKey(0),\n", + " jnp.zeros((1, 100, 100, 1)),\n", + " console_kwargs={\"force_jupyter\": True},\n", + ")" ] }, { @@ -143,7 +146,9 @@ "source": [ "def create_train_state(module, rng, learning_rate, momentum):\n", " \"\"\"Creates an initial `TrainState`.\"\"\"\n", - " params = module.init(rng, jnp.ones([BATCH_SIZE, 100, 100, 1]))['params'] # initialize parameters by passing a template image\n", + " params = module.init(rng, jnp.ones([BATCH_SIZE, 100, 100, 1]))[\n", + " \"params\"\n", + " ] # initialize parameters by passing a template image\n", " tx = optax.sgd(learning_rate, momentum)\n", " return TrainState.create(apply_fn=module.apply, params=params, tx=tx)" ] @@ -156,12 +161,14 @@ "outputs": [], "source": [ "@jax.jit\n", - "def train_step(state, img_batch, label_batch): \n", + "def train_step(state, img_batch, label_batch):\n", " \"\"\"Train for a single step.\"\"\"\n", + "\n", " def loss_fn(params):\n", - " logits = state.apply_fn({'params': params}, img_batch, mutable=False)\n", + " logits = state.apply_fn({\"params\": params}, img_batch, mutable=False)\n", " loss = optax.softmax_cross_entropy(logits=logits, labels=label_batch).mean()\n", 
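# train_step here follows the standard Flax/Optax recipe: differentiate a pure
# loss over the params pytree, then push the gradients through the optimizer
# held in the TrainState. A minimal sketch of that recipe, with a toy
# squared-error loss standing in for the cross-entropy used above:
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

def sgd_step(state: TrainState, x_batch, y_batch):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, x_batch)
        return jnp.mean((preds - y_batch) ** 2)  # toy loss for illustration

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss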
" return loss\n", + "\n", " grad_fn = jax.grad(loss_fn)\n", " grads = grad_fn(state.params)\n", " state = state.apply_gradients(grads=grads)\n", @@ -178,15 +185,25 @@ "@jax.jit\n", "def calc_loss(train_state, imgs, labels):\n", " return optax.softmax_cross_entropy(\n", - " #logits= train_state.apply_fn({'params': train_state.params}, imgs),\n", - " logits= train_state.apply_fn({'params': train_state.params}, imgs, mutable=False),\n", - " labels=labels).mean()\n", + " # logits= train_state.apply_fn({'params': train_state.params}, imgs),\n", + " logits=train_state.apply_fn(\n", + " {\"params\": train_state.params}, imgs, mutable=False\n", + " ),\n", + " labels=labels,\n", + " ).mean()\n", + "\n", "\n", "def calc_loss_batched(train_state, imgs, labels):\n", " n_batches = imgs.shape[0] // BATCH_SIZE\n", - " return jnp.array([calc_loss(train_state, img_batch, label_batch)\n", - " for (img_batch, label_batch) in zip(imgs.reshape(n_batches, BATCH_SIZE, 100, 100, 1),\n", - " labels.reshape(n_batches, BATCH_SIZE, 2))]).mean()" + " return jnp.array(\n", + " [\n", + " calc_loss(train_state, img_batch, label_batch)\n", + " for (img_batch, label_batch) in zip(\n", + " imgs.reshape(n_batches, BATCH_SIZE, 100, 100, 1),\n", + " labels.reshape(n_batches, BATCH_SIZE, 2),\n", + " )\n", + " ]\n", + " ).mean()" ] }, { @@ -196,8 +213,8 @@ "metadata": {}, "outputs": [], "source": [ - "train_data_file = jnp.load('train_data.npz')\n", - "test_data_file = jnp.load('test_data.npz')" + "train_data_file = jnp.load(\"train_data.npz\")\n", + "test_data_file = jnp.load(\"test_data.npz\")" ] }, { @@ -207,8 +224,8 @@ "metadata": {}, "outputs": [], "source": [ - "train_imgs = train_data_file['arr_0']\n", - "train_labels = train_data_file['arr_1']" + "train_imgs = train_data_file[\"arr_0\"]\n", + "train_labels = train_data_file[\"arr_1\"]" ] }, { @@ -228,8 +245,8 @@ "metadata": {}, "outputs": [], "source": [ - "test_imgs = test_data_file['arr_0']\n", - "test_labels = test_data_file['arr_1']" + "test_imgs = test_data_file[\"arr_0\"]\n", + "test_labels = test_data_file[\"arr_1\"]" ] }, { @@ -305,23 +322,25 @@ "\n", " assert N_TRAIN % BATCH_SIZE == 0\n", " for step in range(N_TRAIN // BATCH_SIZE):\n", - " img_batch = train_imgs[step*BATCH_SIZE:(step+1)*BATCH_SIZE]\n", - " label_batch = train_labels[step*BATCH_SIZE:(step+1)*BATCH_SIZE]\n", - " \n", - " state, loss = train_step(state, img_batch, label_batch) \n", + " img_batch = train_imgs[step * BATCH_SIZE : (step + 1) * BATCH_SIZE]\n", + " label_batch = train_labels[step * BATCH_SIZE : (step + 1) * BATCH_SIZE]\n", + "\n", + " state, loss = train_step(state, img_batch, label_batch)\n", " training_losses[-1].append(loss)\n", - " #os.system('nvidia-smi')\n", - " print('.', end='')\n", + " # os.system('nvidia-smi')\n", + " print(\".\", end=\"\")\n", " print()\n", "\n", " epoch_loss = jnp.array(training_losses[-1]).mean()\n", " epoch_train_loss = calc_loss_batched(state, train_imgs, train_labels)\n", " epoch_test_loss = calc_loss_batched(state, test_imgs, test_labels)\n", - " #epoch_train_loss = 'not calculated'\n", - " #epoch_test_loss = 'not calculated'\n", - " print(f'epoch: {epoch}, average loss: {epoch_loss}, '\n", - " f'train loss: {epoch_train_loss}, '\n", - " f'test loss: {epoch_test_loss}')" + " # epoch_train_loss = 'not calculated'\n", + " # epoch_test_loss = 'not calculated'\n", + " print(\n", + " f\"epoch: {epoch}, average loss: {epoch_loss}, \"\n", + " f\"train loss: {epoch_train_loss}, \"\n", + " f\"test loss: {epoch_test_loss}\"\n", + " )" ] }, { @@ -343,15 +362,18 @@ 
"source": [ "IDX = 44\n", "__img = test_imgs[IDX].reshape(1, 100, 100, 1)\n", - "logits = state.apply_fn({'params': state.params}, __img)\n", + "logits = state.apply_fn({\"params\": state.params}, __img)\n", "print(logits)\n", "print(b.utils.normalize_log_scores(logits))\n", "print(test_labels[IDX])\n", - "print(optax.softmax_cross_entropy(\n", - " #logits= train_state.apply_fn({'params': train_state.params}, imgs),\n", + "print(\n", + " optax.softmax_cross_entropy(\n", + " # logits= train_state.apply_fn({'params': train_state.params}, imgs),\n", " logits=logits,\n", - " labels=test_labels[IDX]))\n", - "b.get_depth_image(test_imgs[IDX][...,-1])" + " labels=test_labels[IDX],\n", + " )\n", + ")\n", + "b.get_depth_image(test_imgs[IDX][..., -1])" ] }, { @@ -377,7 +399,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open('param_file.pkl', 'wb') as params_file:\n", + "with open(\"param_file.pkl\", \"wb\") as params_file:\n", " pickle.dump(state.params, params_file)" ] }, @@ -389,11 +411,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.0001, far=2.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=2.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)" @@ -406,10 +424,14 @@ "metadata": {}, "outputs": [], "source": [ - "fork_mesh_path = b.utils.get_assets_dir() + '/ycb_video_models/models/030_fork/nontextured.ply'\n", - "knife_mesh_path = b.utils.get_assets_dir() + '/ycb_video_models/models/032_knife/nontextured.ply'\n", - "box_mesh_path = b.utils.get_assets_dir() + '/bop/ycbv/models/obj_000002.ply'\n", - "table_mesh_path = b.utils.get_assets_dir() + '/sample_objs/cube.obj'\n", + "fork_mesh_path = (\n", + " b.utils.get_assets_dir() + \"/ycb_video_models/models/030_fork/nontextured.ply\"\n", + ")\n", + "knife_mesh_path = (\n", + " b.utils.get_assets_dir() + \"/ycb_video_models/models/032_knife/nontextured.ply\"\n", + ")\n", + "box_mesh_path = b.utils.get_assets_dir() + \"/bop/ycbv/models/obj_000002.ply\"\n", + "table_mesh_path = b.utils.get_assets_dir() + \"/sample_objs/cube.obj\"\n", "fork_scale = knife_scale = 1.0\n", "box_scale = 1e-3\n", "table_scale = 1e-6\n", @@ -449,24 +471,33 @@ "outputs": [], "source": [ "def fork_spoon_from_known_params(is_fork, shift):\n", - " indices = jax.lax.cond(is_fork,\n", - " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX]),\n", - " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, KNIFE_IDX]))\n", + " indices = jax.lax.cond(\n", + " is_fork,\n", + " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX]),\n", + " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, KNIFE_IDX]),\n", + " )\n", "\n", " box_dims = b.RENDERER.model_box_dims[indices]\n", " root_poses = jnp.array([table_pose, table_pose, table_pose])\n", " parents = jnp.array([-1, 0, 0])\n", - " contact_params = jnp.array([[0.0, 0.0, 0.0],\n", - " [*CHEESEITZ_BOX_CONTACT_PARAMS],\n", - " [shift*jnp.cos(jnp.pi/12), -0.05 + shift*jnp.sin(jnp.pi/12), 10*jnp.pi/12]])\n", + " contact_params = jnp.array(\n", + " [\n", + " [0.0, 0.0, 0.0],\n", + " [*CHEESEITZ_BOX_CONTACT_PARAMS],\n", + " [\n", + " shift * jnp.cos(jnp.pi / 12),\n", + " -0.05 + shift * jnp.sin(jnp.pi / 12),\n", + " 10 * jnp.pi / 12,\n", + " ],\n", + " ]\n", + " )\n", " faces_parents = jnp.array([0, 2, 2])\n", " faces_child = jnp.array([0, 3, 3])\n", " poses = b.scene_graph.poses_from_scene_graph(\n", - " root_poses, box_dims, parents, contact_params, faces_parents, 
faces_child)\n", + " root_poses, box_dims, parents, contact_params, faces_parents, faces_child\n", + " )\n", " camera_pose = jnp.eye(4)\n", - " rendered = b.RENDERER.render(\n", - " jnp.linalg.inv(camera_pose) @ poses , indices\n", - " )[...,:3]\n", + " rendered = b.RENDERER.render(jnp.linalg.inv(camera_pose) @ poses, indices)[..., :3]\n", " return (is_fork, rendered)" ] }, @@ -478,8 +509,16 @@ "outputs": [], "source": [ "ss = -0.1\n", - "b.viz.hstack_images([b.viz.scale_image(b.get_depth_image(fork_spoon_from_known_params(True, ss)[1][...,2]), 2),\n", - " b.viz.scale_image(b.get_depth_image(fork_spoon_from_known_params(False, ss)[1][...,2]), 2)])" + "b.viz.hstack_images(\n", + " [\n", + " b.viz.scale_image(\n", + " b.get_depth_image(fork_spoon_from_known_params(True, ss)[1][..., 2]), 2\n", + " ),\n", + " b.viz.scale_image(\n", + " b.get_depth_image(fork_spoon_from_known_params(False, ss)[1][..., 2]), 2\n", + " ),\n", + " ]\n", + ")" ] }, { @@ -489,7 +528,9 @@ "metadata": {}, "outputs": [], "source": [ - "make_onehot = lambda b: jax.lax.cond(b, lambda: jnp.array([0.0, 1.0]), lambda: jnp.array([1.0, 0.0]))" + "make_onehot = lambda b: jax.lax.cond(\n", + " b, lambda: jnp.array([0.0, 1.0]), lambda: jnp.array([1.0, 0.0])\n", + ")" ] }, { @@ -499,11 +540,14 @@ "metadata": {}, "outputs": [], "source": [ - "is_fork, __img, = fork_spoon_from_known_params(False, 0.0)\n", + "(\n", + " is_fork,\n", + " __img,\n", + ") = fork_spoon_from_known_params(False, 0.0)\n", "__img = __img[:, :, 2].reshape(1, 100, 100, 1)\n", - "logits = state.apply_fn({'params': state.params}, __img)\n", + "logits = state.apply_fn({\"params\": state.params}, __img)\n", "\n", - "print(f'true: {make_onehot(is_fork)} predicted: {jax.nn.softmax(logits)}')\n", + "print(f\"true: {make_onehot(is_fork)} predicted: {jax.nn.softmax(logits)}\")\n", "\n", "b.viz.scale_image(b.get_depth_image(__img.reshape(100, 100)), 2)" ] @@ -515,8 +559,8 @@ "metadata": {}, "outputs": [], "source": [ - "#img_batch, label_batch = make_batch(batch_keys)\n", - "#logits = state.apply_fn({'params': state.params}, img_batch)" + "# img_batch, label_batch = make_batch(batch_keys)\n", + "# logits = state.apply_fn({'params': state.params}, img_batch)" ] }, { @@ -526,7 +570,7 @@ "metadata": {}, "outputs": [], "source": [ - "#i = 33 #23 # 13" + "# i = 33 #23 # 13" ] }, { @@ -536,7 +580,7 @@ "metadata": {}, "outputs": [], "source": [ - "#b.viz.scale_image(b.get_depth_image(img_batch[i, :, :, 0]), 2)" + "# b.viz.scale_image(b.get_depth_image(img_batch[i, :, :, 0]), 2)" ] }, { @@ -546,7 +590,7 @@ "metadata": {}, "outputs": [], "source": [ - "#label_batch[i], jax.nn.softmax(logits[i])" + "# label_batch[i], jax.nn.softmax(logits[i])" ] }, { diff --git a/scripts/experiments/icra/fork_knife/fork-knife-datagen.ipynb b/scripts/experiments/icra/fork_knife/fork-knife-datagen.ipynb index 661d5a5b..426a4439 100644 --- a/scripts/experiments/icra/fork_knife/fork-knife-datagen.ipynb +++ b/scripts/experiments/icra/fork_knife/fork-knife-datagen.ipynb @@ -60,11 +60,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.0001, far=2.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=2.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", @@ -73,7 +69,7 @@ "# for idx in range(1,22):\n", "# mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", "# b.RENDERER.add_mesh_from_file(mesh_path, 
scaling_factor=1.0/1000.0)\n", - "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" ] }, { @@ -84,28 +80,41 @@ "outputs": [], "source": [ "import trimesh\n", + "\n", "box_width = 0.02\n", "hammer_width = 0.05\n", "hand_length = 0.15\n", "b1 = trimesh.creation.box(\n", - " np.array(jnp.array([hand_length, box_width,box_width])),\n", - " np.array(b.transform_from_pos(jnp.array([0.0, 0.0, 0.0])))\n", + " np.array(jnp.array([hand_length, box_width, box_width])),\n", + " np.array(b.transform_from_pos(jnp.array([0.0, 0.0, 0.0]))),\n", ")\n", "b2 = trimesh.creation.box(\n", - " np.array(jnp.array([hammer_width,hammer_width, hammer_width])),\n", - " np.array(b.transform_from_pos(jnp.array([hand_length/2 - hammer_width/2, 0.0, 0.0])))\n", + " np.array(jnp.array([hammer_width, hammer_width, hammer_width])),\n", + " np.array(\n", + " b.transform_from_pos(jnp.array([hand_length / 2 - hammer_width / 2, 0.0, 0.0]))\n", + " ),\n", ")\n", "b3 = trimesh.creation.box(\n", - " np.array(jnp.array([hammer_width,hammer_width, hammer_width])),\n", - " np.array(b.transform_from_pos(jnp.array([-hand_length/2 + hammer_width/2, 0.0, 0.0, ])))\n", + " np.array(jnp.array([hammer_width, hammer_width, hammer_width])),\n", + " np.array(\n", + " b.transform_from_pos(\n", + " jnp.array(\n", + " [\n", + " -hand_length / 2 + hammer_width / 2,\n", + " 0.0,\n", + " 0.0,\n", + " ]\n", + " )\n", + " )\n", + " ),\n", ")\n", - "m1 = trimesh.util.concatenate([b1,b2])\n", - "m2 = trimesh.util.concatenate([b1,b2,b3])\n", + "m1 = trimesh.util.concatenate([b1, b2])\n", + "m2 = trimesh.util.concatenate([b1, b2, b3])\n", "b.show_trimesh(\"1\", m2)\n", "\n", "b.utils.mesh.export_mesh(m1, \"m1.obj\")\n", "b.utils.mesh.export_mesh(m2, \"m2.obj\")\n", - "table_mesh_path = b.utils.get_assets_dir() + '/sample_objs/cube.obj'\n", + "table_mesh_path = b.utils.get_assets_dir() + \"/sample_objs/cube.obj\"\n", "\n", "box_mesh = b.utils.make_cuboid_mesh(jnp.array([0.1, 0.1, 0.3]))\n", "b.RENDERER.add_mesh(m1)\n", @@ -161,24 +170,33 @@ "outputs": [], "source": [ "def fork_spoon_from_known_params(is_fork, shift):\n", - " indices = jax.lax.cond(is_fork,\n", - " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX]),\n", - " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, KNIFE_IDX]))\n", + " indices = jax.lax.cond(\n", + " is_fork,\n", + " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX]),\n", + " lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, KNIFE_IDX]),\n", + " )\n", "\n", " box_dims = b.RENDERER.model_box_dims[indices]\n", " root_poses = jnp.array([table_pose, table_pose, table_pose])\n", " parents = jnp.array([-1, 0, 0])\n", - " contact_params = jnp.array([[0.0, 0.0, 0.0],\n", - " [*CHEESEITZ_BOX_CONTACT_PARAMS],\n", - " [shift*jnp.cos(jnp.pi/12), -0.05 + shift*jnp.sin(jnp.pi/12), 10*jnp.pi/12]])\n", + " contact_params = jnp.array(\n", + " [\n", + " [0.0, 0.0, 0.0],\n", + " [*CHEESEITZ_BOX_CONTACT_PARAMS],\n", + " [\n", + " shift * jnp.cos(jnp.pi / 12),\n", + " -0.05 + shift * jnp.sin(jnp.pi / 12),\n", + " 10 * jnp.pi / 12,\n", + " ],\n", + " ]\n", + " )\n", " faces_parents = jnp.array([0, 2, 2])\n", " faces_child = jnp.array([0, 3, 3])\n", " poses = b.scene_graph.poses_from_scene_graph(\n", - " root_poses, box_dims, parents, contact_params, faces_parents, faces_child)\n", + " root_poses, box_dims, parents, 
contact_params, faces_parents, faces_child\n", + " )\n", " camera_pose = jnp.eye(4)\n", - " rendered = b.RENDERER.render(\n", - " jnp.linalg.inv(camera_pose) @ poses , indices\n", - " )[...,:3]\n", + " rendered = b.RENDERER.render(jnp.linalg.inv(camera_pose) @ poses, indices)[..., :3]\n", " return (is_fork, rendered)" ] }, @@ -190,8 +208,16 @@ "outputs": [], "source": [ "ss = 0.05\n", - "b.viz.hstack_images([b.viz.scale_image(b.get_depth_image(fork_spoon_from_known_params(True, ss)[1][...,2]), 2),\n", - " b.viz.scale_image(b.get_depth_image(fork_spoon_from_known_params(False, ss)[1][...,2]), 2)])" + "b.viz.hstack_images(\n", + " [\n", + " b.viz.scale_image(\n", + " b.get_depth_image(fork_spoon_from_known_params(True, ss)[1][..., 2]), 2\n", + " ),\n", + " b.viz.scale_image(\n", + " b.get_depth_image(fork_spoon_from_known_params(False, ss)[1][..., 2]), 2\n", + " ),\n", + " ]\n", + ")" ] }, { @@ -216,7 +242,9 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(b.get_depth_image(fork_spoon_generator(jax.random.PRNGKey(0))[1][...,2]), 2)" + "b.viz.scale_image(\n", + " b.get_depth_image(fork_spoon_generator(jax.random.PRNGKey(0))[1][..., 2]), 2\n", + ")" ] }, { @@ -226,7 +254,9 @@ "metadata": {}, "outputs": [], "source": [ - "make_onehot = lambda b: jax.lax.cond(b, lambda: jnp.array([0.0, 1.0]), lambda: jnp.array([1.0, 0.0]))" + "make_onehot = lambda b: jax.lax.cond(\n", + " b, lambda: jnp.array([0.0, 1.0]), lambda: jnp.array([1.0, 0.0])\n", + ")" ] }, { @@ -238,19 +268,26 @@ "source": [ "@jax.jit\n", "def make_batch(batch_keys):\n", - " #is_forks, imgs = zip(*map(fork_spoon_generator, batch_keys))\n", - " #img_batch = jnp.concatenate([img[:,:,2].reshape(1, 100, 100, 1) for img in imgs], 0)\n", - " #label_batch = jnp.array([make_onehot(is_fork) for is_fork in is_forks])\n", - " #return img_batch, label_batch\n", + " # is_forks, imgs = zip(*map(fork_spoon_generator, batch_keys))\n", + " # img_batch = jnp.concatenate([img[:,:,2].reshape(1, 100, 100, 1) for img in imgs], 0)\n", + " # label_batch = jnp.array([make_onehot(is_fork) for is_fork in is_forks])\n", + " # return img_batch, label_batch\n", " batch_size = batch_keys.shape[0]\n", + "\n", " def loop_body(i, imgs_labels):\n", " imgs, labels = imgs_labels\n", " label, img = fork_spoon_generator(batch_keys[i])\n", - " return (imgs.at[i, :, :, 0].set(img[:, :, 2]),\n", - " labels.at[i, :].set(make_onehot(label)))\n", - " return jax.lax.fori_loop(0, batch_keys.shape[0],\n", - " loop_body,\n", - " (jnp.zeros((batch_size, 100, 100,1)), jnp.zeros((batch_size, 2))))" + " return (\n", + " imgs.at[i, :, :, 0].set(img[:, :, 2]),\n", + " labels.at[i, :].set(make_onehot(label)),\n", + " )\n", + "\n", + " return jax.lax.fori_loop(\n", + " 0,\n", + " batch_keys.shape[0],\n", + " loop_body,\n", + " (jnp.zeros((batch_size, 100, 100, 1)), jnp.zeros((batch_size, 2))),\n", + " )" ] }, { @@ -281,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.savez('train_data.npz', *train_data)" + "jnp.savez(\"train_data.npz\", *train_data)" ] }, { @@ -301,7 +338,7 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.savez('test_data.npz', *test_data)" + "jnp.savez(\"test_data.npz\", *test_data)" ] }, { diff --git a/scripts/experiments/icra/mug/mug.ipynb b/scripts/experiments/icra/mug/mug.ipynb index ce26dcdc..b650a82c 100644 --- a/scripts/experiments/icra/mug/mug.ipynb +++ b/scripts/experiments/icra/mug/mug.ipynb @@ -64,22 +64,22 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " 
cx=50.0, cy=50.0,\n", - " near=0.0001, far=2.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=2.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n", - "\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -91,7 +91,7 @@ "source": [ "table_pose = b.t3d.inverse_pose(\n", " b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.8, .15]),\n", + " jnp.array([0.0, 0.8, 0.15]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 1.0]),\n", " )\n", @@ -111,9 +111,15 @@ "num_position_grids = 51\n", "num_angle_grids = 51\n", "contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " num_position_grids,num_position_grids,num_angle_grids\n", + " -width,\n", + " -width,\n", + " -ang,\n", + " width,\n", + " width,\n", + " ang,\n", + " num_position_grids,\n", + " num_position_grids,\n", + " num_angle_grids,\n", ")" ] }, @@ -149,27 +155,45 @@ " # fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", " widths = [1, 1]\n", " heights = [2]\n", - " spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", - " \n", + " spec = fig.add_gridspec(\n", + " ncols=2, nrows=1, width_ratios=widths, height_ratios=heights\n", + " )\n", + "\n", " ax = fig.add_subplot(spec[0, 0])\n", - " ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + " ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + " ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + " )\n", " # ax.set_title(f\"Observed Depth\")\n", - " \n", - " \n", + "\n", " ax = fig.add_subplot(spec[0, 1])\n", " ax.set_aspect(1.0)\n", - " circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle=\"--\", linewidth=0.5)\n", + " circ = plt.Circle(\n", + " (0, 0),\n", + " radius=1,\n", + " edgecolor=\"black\",\n", + " facecolor=\"None\",\n", + " linestyle=\"--\",\n", + " linewidth=0.5,\n", + " )\n", " ax.add_patch(circ)\n", " ax.set_xlim(-1.1, 1.1)\n", " ax.set_ylim(-1.1, 1.1)\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.scatter(-jnp.sin(sampled_contacts[:,2]),jnp.cos(sampled_contacts[:,2]), color='red',label=\"Posterior Samples\", alpha=0.5, 
s=30)\n", - " ax.scatter(-jnp.sin(gt_contact[2]),jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25)\n", + " ax.scatter(\n", + " -jnp.sin(sampled_contacts[:, 2]),\n", + " jnp.cos(sampled_contacts[:, 2]),\n", + " color=\"red\",\n", + " label=\"Posterior Samples\",\n", + " alpha=0.5,\n", + " s=30,\n", + " )\n", + " ax.scatter(\n", + " -jnp.sin(gt_contact[2]), jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25\n", + " )\n", " ax.set_title(\"Posterior on Orientation (top view)\")\n", " # ax.legend(fontsize=9)\n", " # plt.show()\n", @@ -187,10 +211,9 @@ " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_1\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace_, key, contact_param_grid)\n", " i = scores.argmax()\n", - " return enumerators.update_choices(\n", - " trace_, key,\n", - " contact_param_grid[i]\n", - " )\n", + " return enumerators.update_choices(trace_, key, contact_param_grid[i])\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update)" ] }, @@ -213,16 +236,18 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)), (0.02, jnp.pi, (9,9,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.3, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -233,7 +258,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]" + "key = jax.random.split(key, 2)[0]" ] }, { @@ -263,24 +288,36 @@ ], "source": [ "low, high = jnp.array([-0.2, -0.2, -jnp.pi]), jnp.array([0.2, 0.2, jnp.pi])\n", - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 3,\n", - " \"face_child_1\": 2,\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jax.random.uniform(key, shape=(3,),minval=low, maxval=high)\n", - "}), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.5, -0.5, -2*jnp.pi]), jnp.array([0.5, 0.5, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, intrinsics.fx)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 3,\n", + " \"face_child_1\": 2,\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jax.random.uniform(\n", + " key, shape=(3,), minval=low, maxval=high\n", + " ),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " 
jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.5, -0.5, -2 * jnp.pi]), jnp.array([0.5, 0.5, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " intrinsics.fx,\n", + " ),\n", ")\n", "gt_poses = b.get_poses(trace)\n", "gt_contact = trace[\"contact_params_1\"]\n", @@ -300,7 +337,9 @@ "path = []\n", "path.append(trace)\n", "for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace = c2f_contact_update_jit(trace, key, contact_param_gridding_schedule[c2f_iter])\n", + " trace = c2f_contact_update_jit(\n", + " trace, key, contact_param_gridding_schedule[c2f_iter]\n", + " )\n", " path.append(trace)\n", "print(trace[\"contact_params_1\"])\n", "b.viz_trace_rendered_observed(trace)" @@ -315,10 +354,13 @@ "source": [ "%%time\n", "contact_param_grid = trace[\"contact_params_1\"] + contact_param_deltas\n", - "weights = jnp.concatenate([\n", - " enumerators.enumerate_choices_get_scores(trace, key, cp)\n", - " for cp in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)" + "weights = jnp.concatenate(\n", + " [\n", + " enumerators.enumerate_choices_get_scores(trace, key, cp)\n", + " for cp in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")" ] }, { @@ -330,8 +372,10 @@ "source": [ "key2 = jax.random.split(key, 1)[0]\n", "normalized_weights = b.utils.normalize_log_scores(weights)\n", - "sampled_indices = jax.random.choice(key2,jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights)\n", - "sampled_contact_params = contact_param_grid[sampled_indices]\n" + "sampled_indices = jax.random.choice(\n", + " key2, jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights\n", + ")\n", + "sampled_contact_params = contact_param_grid[sampled_indices]" ] }, { @@ -342,21 +386,21 @@ "outputs": [], "source": [ "fig = plt.figure()\n", - "ax = fig.add_subplot(projection='3d')\n", + "ax = fig.add_subplot(projection=\"3d\")\n", "ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "# make the grid lines transparent\n", - "ax.xaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.yaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.zaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "u, v = np.mgrid[0:2*np.pi:21j, 0:np.pi:11j]\n", - "x = np.cos(u)*np.sin(v)\n", - "y = np.sin(u)*np.sin(v)\n", + "ax.xaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.yaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.zaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "u, v = np.mgrid[0 : 2 * np.pi : 21j, 0 : np.pi : 11j]\n", + "x = np.cos(u) * np.sin(v)\n", + "y = np.sin(u) * np.sin(v)\n", "z = np.cos(v)\n", - "ax.axes.set_xlim3d(-1.1, 1.1) \n", - "ax.axes.set_ylim3d(-1.1, 1.1) \n", - "ax.axes.set_zlim3d(-1.1, 1.1) \n", + "ax.axes.set_xlim3d(-1.1, 1.1)\n", + "ax.axes.set_ylim3d(-1.1, 1.1)\n", + "ax.axes.set_zlim3d(-1.1, 1.1)\n", "ax.set_aspect(\"equal\")\n", "ax.plot_wireframe(x, y, z, color=(0.0, 0.0, 0.0, 0.3), linewidths=0.5)\n", "\n", @@ -376,10 +420,16 @@ "z = 0.1\n", "# for i in np.arange(.1,1.01,.1):\n", "# ax.scatter(points[:,0], points[:,1],points[:,2], s=(40*i*(z*.9+.1))**2, color=(1,0,0,.3/i/10))\n", - "offset = jnp.pi/2\n", - "angle = jnp.pi/4 - jnp.pi/4 - jnp.pi/4 - jnp.pi/4\n", - "for i in np.arange(.1,1.01,.1):\n", - " ax.scatter(np.cos(angle + offset) * scaling, np.sin(angle + offset), 0.0, s=(40*i*(z*.9+.1))**2, color=(1,0,0,.3))\n", + "offset = 
jnp.pi / 2\n", + "angle = jnp.pi / 4 - jnp.pi / 4 - jnp.pi / 4 - jnp.pi / 4\n", + "for i in np.arange(0.1, 1.01, 0.1):\n", + " ax.scatter(\n", + " np.cos(angle + offset) * scaling,\n", + " np.sin(angle + offset),\n", + " 0.0,\n", + " s=(40 * i * (z * 0.9 + 0.1)) ** 2,\n", + " color=(1, 0, 0, 0.3),\n", + " )\n", "\n", "# plt.tight_layout()\n", "plt.savefig(\"sphere.pdf\")" @@ -395,13 +445,18 @@ "scaled_up_intrinsics = b.scale_camera_parameters(intrinsics, 4)\n", "\n", "b.setup_renderer(scaled_up_intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -421,14 +476,14 @@ "metadata": {}, "outputs": [], "source": [ - "depth = img[...,2]\n", + "depth = img[..., 2]\n", "minval = jnp.min(depth[depth > jnp.min(depth)])\n", "maxval = jnp.max(depth[depth < jnp.max(depth)])\n", "depth = depth.at[depth >= intrinsics.far].set(jnp.nan)\n", - "viz_img = np.array(b.viz.scale_image(b.get_depth_image(\n", - " depth, min=minval, max=maxval\n", - "), 3))\n", - "viz_img[viz_img.sum(-1) == 0,:] = 255.0\n", + "viz_img = np.array(\n", + " b.viz.scale_image(b.get_depth_image(depth, min=minval, max=maxval), 3)\n", + ")\n", + "viz_img[viz_img.sum(-1) == 0, :] = 255.0\n", "plt.imshow(viz_img)\n", "plt.xticks([])\n", "plt.yticks([])\n", @@ -471,7 +526,7 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.linalg.norm(points,axis=-1)" + "jnp.linalg.norm(points, axis=-1)" ] }, { @@ -491,48 +546,57 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "fig = plt.figure(constrained_layout=True)\n", "\n", - "observation = trace[\"image\"]\n", + "observation = trace[\"image\"]\n", "\n", "# fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", "widths = [1, 1]\n", "heights = [2]\n", - "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", + "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths, height_ratios=heights)\n", "\n", "ax = fig.add_subplot(spec[0, 0])\n", - "ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + "ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", "ax.get_xaxis().set_visible(False)\n", "ax.get_yaxis().set_visible(False)\n", - "ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + "ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + ")\n", "# ax.set_title(f\"Observed Depth\")\n", "\n", "\n", "dist = 0.6\n", "ax = fig.add_subplot(spec[0, 1])\n", "ax.quiver(\n", - " 
sampled_contact_params[:,0],sampled_contact_params[:,1],\n",
-    "    -jnp.sin(sampled_contact_params[:,2]),jnp.cos(sampled_contact_params[:,2]),\n",
+    "    sampled_contact_params[:, 0],\n",
+    "    sampled_contact_params[:, 1],\n",
+    "    -jnp.sin(sampled_contact_params[:, 2]),\n",
+    "    jnp.cos(sampled_contact_params[:, 2]),\n",
     "    scale=3.0,\n",
-    "    alpha=0.1\n",
-    "    )\n",
+    "    alpha=0.1,\n",
+    ")\n",
     "\n",
     "ax.quiver(\n",
-    "    gt_contact[0],gt_contact[1],\n",
-    "    -jnp.sin(gt_contact[2]), jnp.cos(gt_contact[2]),\n",
+    "    gt_contact[0],\n",
+    "    gt_contact[1],\n",
+    "    -jnp.sin(gt_contact[2]),\n",
+    "    jnp.cos(gt_contact[2]),\n",
     "    scale=5.0,\n",
     "    alpha=0.8,\n",
-    "    color=\"red\"\n",
+    "    color=\"red\",\n",
     ")\n",
     "\n",
     "ax.set_aspect(1.0)\n",
     "from matplotlib.patches import Rectangle\n",
-    "ax.add_patch(Rectangle((gt_contact[0]-width, gt_contact[1]-width), 2*width, 2*width,fill=None))\n",
     "\n",
-    "ax.set_xlim(gt_contact[0]-width-0.02, gt_contact[0]+width+0.02)\n",
-    "ax.set_ylim(gt_contact[1]-width-0.02, gt_contact[1]+width+0.02)"
+    "ax.add_patch(\n",
+    "    Rectangle(\n",
+    "        (gt_contact[0] - width, gt_contact[1] - width), 2 * width, 2 * width, fill=None\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "ax.set_xlim(gt_contact[0] - width - 0.02, gt_contact[0] + width + 0.02)\n",
+    "ax.set_ylim(gt_contact[1] - width - 0.02, gt_contact[1] + width + 0.02)"
   ]
  },
  {
@@ -547,8 +611,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
    "best_cell_idx = jnp.abs(contact_param_grid - gt_contact).sum(1).argmin()\n",
    "print(gt_contact, contact_param_grid[best_cell_idx])\n",
    "normalize_log_weights = weights - b.logsumexp(weights)\n",
-    "assert(weights.shape[0] == contact_param_grid.shape[0])\n",
-    "volume = (width / num_position_grids)**2 * (2*jnp.pi / num_angle_grids)\n",
+    "assert weights.shape[0] == contact_param_grid.shape[0]\n",
+    "volume = (width / num_position_grids) ** 2 * (2 * jnp.pi / num_angle_grids)\n",
    "log_likelihood = normalize_log_weights[best_cell_idx] - jnp.log(volume)\n",
    "print(log_likelihood)"
   ]
  },
  {
diff --git a/scripts/experiments/icra/scene_parse/nice_top_figure.ipynb b/scripts/experiments/icra/scene_parse/nice_top_figure.ipynb
index a53f9219..8550feac 100644
--- a/scripts/experiments/icra/scene_parse/nice_top_figure.ipynb
+++ b/scripts/experiments/icra/scene_parse/nice_top_figure.ipynb
@@ -21,11 +21,12 @@
    "import glob\n",
    "import bayes3d.neural\n",
    "import pickle\n",
+    "\n",
    "# Can be helpful for debugging:\n",
-    "# jax.config.update('jax_enable_checks', True) \n",
+    "# jax.config.update('jax_enable_checks', True)\n",
    "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n",
    "import genjax\n",
-    "from bayes3d.viz.open3dviz import Open3DVisualizer\n"
+    "from bayes3d.viz.open3dviz import Open3DVisualizer"
   ]
  },
  {
@@ -46,11 +47,7 @@
    "outputs": [],
    "source": [
    "base_intrinsics = b.Intrinsics(\n",
-    "    height=50,\n",
-    "    width=50,\n",
-    "    fx=250.0, fy=250.0,\n",
-    "    cx=25.0, cy=25.0,\n",
-    "    near=0.01, far=20.0\n",
+    "    height=50, width=50, fx=250.0, fy=250.0, cx=25.0, cy=25.0, near=0.01, far=20.0\n",
    ")\n",
    "intrinsics = b.scale_camera_parameters(base_intrinsics, 10)"
   ]
  },
  {
@@ -93,11 +90,10 @@
    "outputs": [],
    "source": [
    "camera_pose = b.t3d.transform_from_pos_target_up(\n",
-    "        jnp.array([0.0, -1.5, 1.50]),\n",
-    "        jnp.array([0.0, 0.0, 0.0]),\n",
-    "        jnp.array([0.0, 0.0, 1.0]),\n",
-    "    )\n",
-    "\n"
+    "    jnp.array([0.0, -1.5, 1.50]),\n",
+    "    jnp.array([0.0, 0.0, 0.0]),\n",
+    "    jnp.array([0.0, 0.0, 1.0]),\n",
+    ")"
   ]
  },
  {
@@ -146,7 +142,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ids = [12,13,10]\n",
+    "ids = [12, 13, 10]\n",
    "colors = b.distinct_colors(10)\n",
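[Editor's note] The mug.ipynb cell just above (fixed here: `w1eights` was a garbled `weights`) turns enumerated grid scores into a log-density estimate: the scores are normalized with logsumexp, and the best cell's log-probability is reduced by the log of the grid-cell volume. A minimal, self-contained sketch of that computation; the `scores` values and the `dx`/`dy`/`dtheta` spacings below are made-up stand-ins for the notebook's `weights` and `contact_param_grid`:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up log scores, one per (x, y, angle) grid cell; in the notebook these
# come from enumerators.enumerate_choices_get_scores over the contact grid.
scores = jnp.array([-3.0, -1.0, -2.0, -0.5])

# Normalize in log space so the enumerated grid defines a discrete posterior.
log_probs = scores - logsumexp(scores)

# A cell's probability becomes an approximate density by subtracting the log
# of the cell volume; dx, dy, dtheta are assumed grid spacings.
dx = dy = 0.01
dtheta = 2 * jnp.pi / 51
log_density = log_probs.max() - jnp.log(dx * dy * dtheta)
print(log_density)
```

The same normalized weights are what `jax.random.choice(..., p=normalized_weights)` draws from when the notebook samples contact parameters for the posterior plots.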
"IDX = 0" ] @@ -158,8 +154,10 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(ids[IDX]+1).rjust(6, '0') + \".ply\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(ids[IDX] + 1).rjust(6, \"0\") + \".ply\"\n", + ")\n", "mesh = trimesh.load(mesh_path)" ] }, @@ -172,7 +170,7 @@ "source": [ "resolution = 0.005\n", "v = b.utils.voxelize(vertices, resolution)\n", - "new_mesh = b.utils.make_voxel_mesh_from_point_cloud(v, resolution)\n" + "new_mesh = b.utils.make_voxel_mesh_from_point_cloud(v, resolution)" ] }, { @@ -194,7 +192,7 @@ "outputs": [], "source": [ "mesh = trimesh.load(\"toy_plane.ply\")\n", - "b.show_trimesh(\"1\", mesh, color=(0.7, 0.1, 0.1))\n" + "b.show_trimesh(\"1\", mesh, color=(0.7, 0.1, 0.1))" ] }, { diff --git a/scripts/experiments/icra/scene_parse/real_airplane.ipynb b/scripts/experiments/icra/scene_parse/real_airplane.ipynb index c12822ff..4c5ce296 100644 --- a/scripts/experiments/icra/scene_parse/real_airplane.ipynb +++ b/scripts/experiments/icra/scene_parse/real_airplane.ipynb @@ -21,8 +21,9 @@ "import glob\n", "import bayes3d.neural\n", "import pickle\n", + "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", + "# jax.config.update('jax_enable_checks', True)\n", "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n", "import genjax" ] @@ -55,9 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "paths = glob.glob(\n", - " \"panda_scans_v5//*.pkl\"\n", - ")\n", + "paths = glob.glob(\"panda_scans_v5//*.pkl\")\n", "all_data = pickle.load(open(paths[3], \"rb\"))\n", "IDX = 1\n", "data = all_data[IDX]" @@ -71,15 +70,17 @@ "outputs": [], "source": [ "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + "camera_pose = data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", - "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n", + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 10000.0)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -112,8 +113,12 @@ "outputs": [], "source": [ "plane_pose, plane_dims = b.utils.find_plane_and_dims(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3), \n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=1#0.1\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=1, # 0.1\n", ")" ] }, @@ -125,7 +130,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, 
rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -136,7 +146,7 @@ "metadata": {}, "outputs": [], "source": [ - "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original)*1.0, scaling_factor)" + "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original) * 1.0, scaling_factor)" ] }, { @@ -146,7 +156,9 @@ "metadata": {}, "outputs": [], "source": [ - "observed_depth = (rgbd_scaled_down.depth * mask) + (1.0 - mask)* rgbd_scaled_down.intrinsics.far" + "observed_depth = (rgbd_scaled_down.depth * mask) + (\n", + " 1.0 - mask\n", + ") * rgbd_scaled_down.intrinsics.far" ] }, { @@ -157,7 +169,9 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1, 3)\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -171,7 +185,10 @@ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -192,17 +209,19 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.05, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.1, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.05, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n" + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -212,25 +231,35 @@ "metadata": {}, "outputs": [], "source": [ - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(1),\n", - " \"id_1\": jnp.int32(0),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": plane_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0])\n", - "}), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, 1.0)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(1),\n", + " \"id_1\": jnp.int32(0),\n", + " 
\"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": plane_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0]),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " 1.0,\n", + " ),\n", ")\n", "b.viz_trace_meshcat(trace)\n", "print(trace.get_score())" @@ -247,11 +276,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[f\"contact_params_1\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] }, diff --git a/scripts/experiments/icra/scene_parse/real_mug.ipynb b/scripts/experiments/icra/scene_parse/real_mug.ipynb index 635f1fcd..4f2597c7 100644 --- a/scripts/experiments/icra/scene_parse/real_mug.ipynb +++ b/scripts/experiments/icra/scene_parse/real_mug.ipynb @@ -21,8 +21,9 @@ "import glob\n", "import bayes3d.neural\n", "import pickle\n", + "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", + "# jax.config.update('jax_enable_checks', True)\n", "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n", "import genjax" ] @@ -55,9 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "paths = glob.glob(\n", - " \"data/*.pkl\"\n", - ")\n", + "paths = glob.glob(\"data/*.pkl\")\n", "paths\n", "all_data = pickle.load(open(paths[8], \"rb\"))\n", "IDX = 1\n", @@ -72,15 +71,17 @@ "outputs": [], "source": [ "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + "camera_pose = data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", - "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n" + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 10000.0)\n", + ")" ] }, { @@ -122,8 +123,12 @@ "outputs": [], "source": [ "plane_pose, plane_dims = b.utils.find_plane_and_dims(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3), \n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=1#0.1\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + " 
ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=1, # 0.1\n", ")" ] }, @@ -135,7 +140,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -146,7 +156,7 @@ "metadata": {}, "outputs": [], "source": [ - "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original)*1.0, scaling_factor)" + "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original) * 1.0, scaling_factor)" ] }, { @@ -156,7 +166,9 @@ "metadata": {}, "outputs": [], "source": [ - "observed_depth = (rgbd_scaled_down.depth * mask) + (1.0 - mask)* rgbd_scaled_down.intrinsics.far" + "observed_depth = (rgbd_scaled_down.depth * mask) + (\n", + " 1.0 - mask\n", + ") * rgbd_scaled_down.intrinsics.far" ] }, { @@ -167,7 +179,9 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1, 3)\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -179,13 +193,18 @@ "outputs": [], "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -196,17 +215,19 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.05, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.1, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.05, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n" + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -216,25 +237,35 @@ "metadata": {}, "outputs": [], "source": [ - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " 
\"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": plane_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0])\n", - "}), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, 1.0)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": plane_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0]),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " 1.0,\n", + " ),\n", ")\n", "b.viz_trace_meshcat(trace)\n", "print(trace.get_score())" @@ -259,11 +290,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[f\"contact_params_1\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] } diff --git a/scripts/experiments/icra/scene_parse/real_multiobject-Copy1.ipynb b/scripts/experiments/icra/scene_parse/real_multiobject-Copy1.ipynb index 12e85bfb..0561de52 100644 --- a/scripts/experiments/icra/scene_parse/real_multiobject-Copy1.ipynb +++ b/scripts/experiments/icra/scene_parse/real_multiobject-Copy1.ipynb @@ -21,8 +21,9 @@ "import glob\n", "import bayes3d.neural\n", "import pickle\n", + "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", + "# jax.config.update('jax_enable_checks', True)\n", "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n", "import genjax" ] @@ -55,9 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "paths = glob.glob(\n", - " \"panda_scans_v6/*.pkl\"\n", - ")\n", + "paths = glob.glob(\"panda_scans_v6/*.pkl\")\n", "all_data = pickle.load(open(paths[0], \"rb\"))\n", "IDX = 1\n", "data = all_data[IDX]" @@ -71,15 +70,17 @@ "outputs": [], "source": [ "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + "camera_pose = 
data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", - "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n", + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 10000.0)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -90,7 +91,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(rgbd_original.depth,max=1.5)" + "b.get_depth_image(rgbd_original.depth, max=1.5)" ] }, { @@ -101,7 +102,7 @@ "outputs": [], "source": [ "scaling_factor = 0.23\n", - "rgbd_scaled_down = b.RGBD.scale_rgbd(rgbd_original, scaling_factor)\n" + "rgbd_scaled_down = b.RGBD.scale_rgbd(rgbd_original, scaling_factor)" ] }, { @@ -120,8 +121,12 @@ "outputs": [], "source": [ "plane_pose, plane_dims = b.utils.find_plane_and_dims(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3), \n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=0.1,\n", ")" ] }, @@ -132,7 +137,9 @@ "metadata": {}, "outputs": [], "source": [ - "plane_pose = plane_pose @ b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi)" + "plane_pose = plane_pose @ b.transform_from_axis_angle(\n", + " jnp.array([1.0, 0.0, 0.0]), jnp.pi\n", + ")" ] }, { @@ -143,7 +150,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -154,7 +166,7 @@ "metadata": {}, "outputs": [], "source": [ - "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original)*1.0, scaling_factor)" + "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original) * 1.0, scaling_factor)" ] }, { @@ -164,7 +176,9 @@ "metadata": {}, "outputs": [], "source": [ - "observed_depth = (rgbd_scaled_down.depth * mask) + (1.0 - mask)* rgbd_scaled_down.intrinsics.far" + "observed_depth = (rgbd_scaled_down.depth * mask) + (\n", + " 1.0 - mask\n", + ") * rgbd_scaled_down.intrinsics.far" ] }, { @@ -175,7 +189,9 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1, 3)\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -208,13 +224,16 @@ "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(13+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(10+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + 
"model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(13 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(10 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -236,18 +255,21 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.4, jnp.pi, (11,11,11)), (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.01, 0.0, (21,21,1)),\n", - " (0.01, jnp.pi/10, (5,5,21)),(0.01, jnp.pi/20, (5,5,21))\n", + " (0.4, jnp.pi, (11, 11, 11)),\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.01, jnp.pi / 10, (5, 5, 21)),\n", + " (0.01, jnp.pi / 20, (5, 5, 21)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n" + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -257,24 +279,34 @@ "metadata": {}, "outputs": [], "source": [ - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(3),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": plane_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", - " \"variance\": 0.001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0])\n", - "}), (\n", - " jnp.arange(1),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.6, -0.6, -4*jnp.pi]), jnp.array([0.6, 0.6, 4*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, 1.0)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(3),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": plane_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", + " \"variance\": 0.001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0]),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(1),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.6, -0.6, -4 * jnp.pi]), jnp.array([0.6, 0.6, 4 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " 1.0,\n", + " ),\n", ")\n", "b.viz_trace_meshcat(trace)\n", "print(trace.get_score())" @@ 
-287,7 +319,7 @@ "metadata": {}, "outputs": [], "source": [ - "object_number_to_id = [None, 2, 1,0]\n", + "object_number_to_id = [None, 2, 1, 0]\n", "# object_number_to_id = [None, 1]" ] }, @@ -309,7 +341,7 @@ "outputs": [], "source": [ "address = f\"contact_params_{OBJECT_NUMBER}\"\n", - "trace = b.add_object_jit(trace, key, object_number_to_id[OBJECT_NUMBER], 0, 2,3)\n", + "trace = b.add_object_jit(trace, key, object_number_to_id[OBJECT_NUMBER], 0, 2, 3)\n", "enumerators = b.make_enumerator([address])\n", "b.viz_trace_meshcat(trace)" ] @@ -326,14 +358,11 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[address]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " traces.append(trace)\n", " b.viz_trace_meshcat(trace)\n", - "b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)" + "b.get_depth_image(b.get_rendered_image(trace)[..., 2], max=1.0)" ] }, { @@ -353,18 +382,26 @@ "metadata": {}, "outputs": [], "source": [ - "depth_viz = b.viz.resize_image(b.get_depth_image(rgbd_original.depth,max=1.5), b.RENDERER.intrinsics.height, b.RENDERER.intrinsics.width)\n", - "depth_reconstruction_viz = b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)\n", - "seg_viz = b.get_depth_image(b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:,:,3], max=5.0)\n", - "rgb_viz = b.resize_image(b.get_rgb_image(rgbd_original.rgb), b.RENDERER.intrinsics.height, b.RENDERER.intrinsics.width)\n", - "overlay_viz = b.overlay_image(b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height,rgb_viz.width), rgb_viz)\n", - "b.vstack_images([\n", - " depth_viz,\n", - " depth_reconstruction_viz,\n", - " seg_viz,\n", - " overlay_viz\n", - "])\n", - " " + "depth_viz = b.viz.resize_image(\n", + " b.get_depth_image(rgbd_original.depth, max=1.5),\n", + " b.RENDERER.intrinsics.height,\n", + " b.RENDERER.intrinsics.width,\n", + ")\n", + "depth_reconstruction_viz = b.get_depth_image(\n", + " b.get_rendered_image(trace)[..., 2], max=1.0\n", + ")\n", + "seg_viz = b.get_depth_image(\n", + " b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:, :, 3], max=5.0\n", + ")\n", + "rgb_viz = b.resize_image(\n", + " b.get_rgb_image(rgbd_original.rgb),\n", + " b.RENDERER.intrinsics.height,\n", + " b.RENDERER.intrinsics.width,\n", + ")\n", + "overlay_viz = b.overlay_image(\n", + " b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height, rgb_viz.width), rgb_viz\n", + ")\n", + "b.vstack_images([depth_viz, depth_reconstruction_viz, seg_viz, overlay_viz])" ] }, { @@ -384,7 +421,9 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:,:,3], max=5.0)" + "b.get_depth_image(\n", + " b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:, :, 3], max=5.0\n", + ")" ] }, { @@ -404,7 +443,9 @@ "metadata": {}, "outputs": [], "source": [ - "depth_reconstruction_viz = b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)\n", + "depth_reconstruction_viz = b.get_depth_image(\n", + " b.get_rendered_image(trace)[..., 2], max=1.0\n", + ")\n", "rgb_viz = b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -415,7 +456,9 @@ "metadata": {}, 
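[Editor's note] Every refinement loop in these notebooks follows the same coarse-to-fine pattern: add a shrinking grid of (x, y, angle) deltas to the current estimate, score all candidates, keep the argmax, and repeat with a finer grid. A self-contained toy version under stated assumptions — the quadratic `score_fn` stands in for the enumerated trace scores, and `make_grid` stands in for what `b.utils.make_translation_grid_enumeration_3d` is used for here:

```python
import jax.numpy as jnp

def make_grid(half_width, half_angle, nums):
    # Cartesian grid of (dx, dy, dtheta) deltas centered at the origin.
    xs = jnp.linspace(-half_width, half_width, nums[0])
    ys = jnp.linspace(-half_width, half_width, nums[1])
    ts = jnp.linspace(-half_angle, half_angle, nums[2])
    return jnp.stack(jnp.meshgrid(xs, ys, ts, indexing="ij"), axis=-1).reshape(-1, 3)

# Made-up ground truth and score function, peaked at the ground truth; the
# notebooks instead score rendered scenes against the observed depth image.
target = jnp.array([0.07, -0.03, 0.5])
score_fn = lambda params: -jnp.sum((params - target) ** 2, axis=-1)

schedule = [
    (0.3, jnp.pi, (11, 11, 11)),
    (0.1, jnp.pi / 3, (11, 11, 11)),
    (0.02, jnp.pi / 10, (11, 11, 11)),
]
estimate = jnp.zeros(3)
for half_width, half_angle, nums in schedule:
    candidates = estimate + make_grid(half_width, half_angle, nums)
    estimate = candidates[score_fn(candidates).argmax()]
print(estimate)  # approaches `target` as the grids shrink
```

Because each pass re-centers on the previous argmax, the final precision is set by the last grid's spacing, while the total number of evaluations stays a small multiple of a single grid's size.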
"outputs": [], "source": [ - "b.overlay_image(b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height,rgb_viz.width), rgb_viz)" + "b.overlay_image(\n", + " b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height, rgb_viz.width), rgb_viz\n", + ")" ] }, { @@ -433,7 +476,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(depth_reconstruction_viz, 1/scaling_factor).size" + "b.viz.scale_image(depth_reconstruction_viz, 1 / scaling_factor).size" ] }, { @@ -475,7 +518,7 @@ "source": [ "idx = 0\n", "contact_param_deltas = contact_param_gridding_schedule[idx]\n", - "contact_param_grid = contact_param_deltas + trace[address]\n" + "contact_param_grid = contact_param_deltas + trace[address]" ] }, { @@ -485,7 +528,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]" + "key = jax.random.split(key, 2)[0]" ] }, { @@ -497,7 +540,9 @@ "source": [ "contact_param_deltas = contact_param_gridding_schedule[idx]\n", "contact_param_grid = contact_param_deltas + trace[address]\n", - "indices_in_contact_param_grid = jax.random.choice(key, contact_param_grid.shape[0], shape=(50,))" + "indices_in_contact_param_grid = jax.random.choice(\n", + " key, contact_param_grid.shape[0], shape=(50,)\n", + ")" ] }, { @@ -509,11 +554,8 @@ "source": [ "images = []\n", "for i in indices_in_contact_param_grid:\n", - " trace_ = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", - " images.append(b.get_depth_image(b.get_rendered_image(trace_)[...,2], max=1.5))" + " trace_ = enumerators.update_choices(trace, key, contact_param_grid[i])\n", + " images.append(b.get_depth_image(b.get_rendered_image(trace_)[..., 2], max=1.5))" ] }, { @@ -523,7 +565,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.hvstack_images(images, 10,5)" + "b.hvstack_images(images, 10, 5)" ] }, { @@ -538,11 +580,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[address]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " traces.append(trace)\n", " b.viz_trace_meshcat(trace)" ] @@ -554,7 +593,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0),5)" + "b.viz.scale_image(b.get_depth_image(b.get_rendered_image(trace)[..., 2], max=1.0), 5)" ] }, { @@ -574,7 +613,12 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(traces[0])[...,2], 1.0)),5)" + "b.viz.scale_image(\n", + " b.get_rgb_image(\n", + " get_depth_image_alternate(b.get_rendered_image(traces[0])[..., 2], 1.0)\n", + " ),\n", + " 5,\n", + ")" ] }, { @@ -586,13 +630,16 @@ "source": [ "b.setup_renderer(rgbd_original.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(13+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(10+1).rjust(6, '0') + \".ply\")\n", - 
"b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(13 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(10 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -620,7 +667,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(img[:,:,2],max=1.5)" + "b.get_depth_image(img[:, :, 2], max=1.5)" ] }, { @@ -666,16 +713,20 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,1)[0]\n", - "new_object_idx = jax.random.choice(key,3)\n", + "key = jax.random.split(key, 1)[0]\n", + "new_object_idx = jax.random.choice(key, 3)\n", "contact_param_grid = contact_param_gridding_schedule[0] + jnp.zeros(3)\n", - "key = jax.random.split(key,1)[0]\n", - "contact_param_random = contact_param_grid[jax.random.choice(key, contact_param_grid.shape[0]),:]\n", + "key = jax.random.split(key, 1)[0]\n", + "contact_param_random = contact_param_grid[\n", + " jax.random.choice(key, contact_param_grid.shape[0]), :\n", + "]\n", "print(contact_param_random)\n", "trace_ = b.update_address(trace, key, address, contact_param_random)\n", "trace_ = b.update_address(trace_, key, f\"id_{OBJECT_NUMBER}\", new_object_idx)\n", - "counter +=1\n", - "b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace_)[...,2], 1.0)).save(f\"{counter}.png\")\n" + "counter += 1\n", + "b.get_rgb_image(\n", + " get_depth_image_alternate(b.get_rendered_image(trace_)[..., 2], 1.0)\n", + ").save(f\"{counter}.png\")" ] }, { @@ -699,11 +750,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[f\"contact_params_1\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] }, @@ -722,7 +770,12 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace)[...,2], 1.0)),5)" + "b.viz.scale_image(\n", + " b.get_rgb_image(\n", + " get_depth_image_alternate(b.get_rendered_image(trace)[..., 2], 1.0)\n", + " ),\n", + " 5,\n", + ")" ] }, { @@ -733,7 +786,7 @@ "outputs": [], "source": [ "enumerators = b.make_enumerator([f\"contact_params_2\"])\n", - "trace = b.add_object_jit(trace, key, 1, 0, 2,3)\n", + "trace = b.add_object_jit(trace, key, 1, 0, 2, 3)\n", "b.viz_trace_meshcat(trace)" ] }, @@ -748,11 +801,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[f\"contact_params_2\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= 
jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] }, @@ -764,7 +814,7 @@ "outputs": [], "source": [ "enumerators = b.make_enumerator([f\"contact_params_3\"])\n", - "trace = b.add_object_jit(trace, key, 0, 0, 2,3)\n", + "trace = b.add_object_jit(trace, key, 0, 0, 2, 3)\n", "b.viz_trace_meshcat(trace)" ] }, @@ -779,11 +829,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = contact_param_deltas + trace[f\"contact_params_3\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] }, @@ -800,10 +847,8 @@ " if maxval is None:\n", " maxval = jnp.max(depth[depth < jnp.max(depth)])\n", " depth = depth.at[depth >= far].set(jnp.nan)\n", - " viz_img = np.array(b.get_depth_image(\n", - " depth, min=minval, max=maxval\n", - " ))\n", - " viz_img[viz_img.sum(-1) == 0,:] = 255.0\n", + " viz_img = np.array(b.get_depth_image(depth, min=minval, max=maxval))\n", + " viz_img[viz_img.sum(-1) == 0, :] = 255.0\n", " return viz_img" ] }, @@ -814,7 +859,12 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace)[...,2], 1.0)),5)" + "b.viz.scale_image(\n", + " b.get_rgb_image(\n", + " get_depth_image_alternate(b.get_rendered_image(trace)[..., 2], 1.0)\n", + " ),\n", + " 5,\n", + ")" ] }, { @@ -824,7 +874,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_rgb_image(get_depth_image_alternate(jnp.array(rgbd_original.depth),1.0))" + "b.get_rgb_image(get_depth_image_alternate(jnp.array(rgbd_original.depth), 1.0))" ] }, { diff --git a/scripts/experiments/icra/scene_parse/real_multiobject.ipynb b/scripts/experiments/icra/scene_parse/real_multiobject.ipynb index 9322af6a..d1534713 100644 --- a/scripts/experiments/icra/scene_parse/real_multiobject.ipynb +++ b/scripts/experiments/icra/scene_parse/real_multiobject.ipynb @@ -21,8 +21,9 @@ "import glob\n", "import bayes3d.neural\n", "import pickle\n", + "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", + "# jax.config.update('jax_enable_checks', True)\n", "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n", "import genjax" ] @@ -55,9 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "paths = glob.glob(\n", - " \"panda_scans_v6/*.pkl\"\n", - ")\n", + "paths = glob.glob(\"panda_scans_v6/*.pkl\")\n", "all_data = pickle.load(open(paths[0], \"rb\"))\n", "IDX = 0\n", "data = all_data[IDX]" @@ -71,15 +70,17 @@ "outputs": [], "source": [ "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + 
"camera_pose = data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", - "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n", + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 10000.0)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -90,7 +91,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.get_depth_image(rgbd_original.depth,max=1.5)" + "b.get_depth_image(rgbd_original.depth, max=1.5)" ] }, { @@ -112,8 +113,12 @@ "outputs": [], "source": [ "plane_pose, plane_dims = b.utils.find_plane_and_dims(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3), \n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=0.1,\n", ")" ] }, @@ -124,7 +129,9 @@ "metadata": {}, "outputs": [], "source": [ - "plane_pose = plane_pose @ b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi)" + "plane_pose = plane_pose @ b.transform_from_axis_angle(\n", + " jnp.array([1.0, 0.0, 0.0]), jnp.pi\n", + ")" ] }, { @@ -135,7 +142,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -146,7 +158,7 @@ "metadata": {}, "outputs": [], "source": [ - "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original)*1.0, scaling_factor)" + "mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original) * 1.0, scaling_factor)" ] }, { @@ -156,7 +168,9 @@ "metadata": {}, "outputs": [], "source": [ - "observed_depth = (rgbd_scaled_down.depth * mask) + (1.0 - mask)* rgbd_scaled_down.intrinsics.far" + "observed_depth = (rgbd_scaled_down.depth * mask) + (\n", + " 1.0 - mask\n", + ") * rgbd_scaled_down.intrinsics.far" ] }, { @@ -167,7 +181,9 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1, 3)\n", + ")\n", "b.show_pose(\"table\", plane_pose)" ] }, @@ -190,13 +206,16 @@ "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(13+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(10+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(13 + 1).rjust(6, \"0\") + \".ply\")\n", + 
"b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(10 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -218,17 +237,19 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.05, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.1, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.05, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n" + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -238,24 +259,34 @@ "metadata": {}, "outputs": [], "source": [ - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(3),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": plane_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0])\n", - "}), (\n", - " jnp.arange(1),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, 1.0)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(3),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": plane_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"image\": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0]),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(1),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " 1.0,\n", + " ),\n", ")\n", "b.viz_trace_meshcat(trace)\n", "print(trace.get_score())" @@ -280,7 +311,7 @@ "metadata": {}, "outputs": [], "source": [ - "trace = b.add_object_jit(trace, key, 1, 0, 2,3)\n", + "trace = b.add_object_jit(trace, key, 1, 0, 2, 3)\n", "b.viz_trace_meshcat(trace)" ] }, @@ -295,11 +326,8 @@ " contact_param_deltas = contact_param_gridding_schedule[idx]\n", " contact_param_grid = 
contact_param_deltas + trace[address]\n", " scores = enumerators.enumerate_choices_get_scores(trace, key, contact_param_grid)\n", - " i= jnp.unravel_index(scores.argmax(), scores.shape)\n", - " trace = enumerators.update_choices(\n", - " trace, key,\n", - " contact_param_grid[i]\n", - " )\n", + " i = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " trace = enumerators.update_choices(trace, key, contact_param_grid[i])\n", " b.viz_trace_meshcat(trace)" ] }, @@ -310,18 +338,26 @@ "metadata": {}, "outputs": [], "source": [ - "depth_viz = b.viz.resize_image(b.get_depth_image(rgbd_original.depth,max=1.5), b.RENDERER.intrinsics.height, b.RENDERER.intrinsics.width)\n", - "depth_reconstruction_viz = b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)\n", - "seg_viz = b.get_depth_image(b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:,:,3], max=5.0)\n", - "rgb_viz = b.resize_image(b.get_rgb_image(rgbd_original.rgb), b.RENDERER.intrinsics.height, b.RENDERER.intrinsics.width)\n", - "overlay_viz = b.overlay_image(b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height,rgb_viz.width), rgb_viz)\n", - "b.vstack_images([\n", - " depth_viz,\n", - " depth_reconstruction_viz,\n", - " seg_viz,\n", - " overlay_viz\n", - "])\n", - " " + "depth_viz = b.viz.resize_image(\n", + " b.get_depth_image(rgbd_original.depth, max=1.5),\n", + " b.RENDERER.intrinsics.height,\n", + " b.RENDERER.intrinsics.width,\n", + ")\n", + "depth_reconstruction_viz = b.get_depth_image(\n", + " b.get_rendered_image(trace)[..., 2], max=1.0\n", + ")\n", + "seg_viz = b.get_depth_image(\n", + " b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:, :, 3], max=5.0\n", + ")\n", + "rgb_viz = b.resize_image(\n", + " b.get_rgb_image(rgbd_original.rgb),\n", + " b.RENDERER.intrinsics.height,\n", + " b.RENDERER.intrinsics.width,\n", + ")\n", + "overlay_viz = b.overlay_image(\n", + " b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height, rgb_viz.width), rgb_viz\n", + ")\n", + "b.vstack_images([depth_viz, depth_reconstruction_viz, seg_viz, overlay_viz])" ] }, { diff --git a/scripts/experiments/icra/scene_parse/scene_parse.ipynb b/scripts/experiments/icra/scene_parse/scene_parse.ipynb index 1c09d914..7767b894 100644 --- a/scripts/experiments/icra/scene_parse/scene_parse.ipynb +++ b/scripts/experiments/icra/scene_parse/scene_parse.ipynb @@ -60,21 +60,22 @@ ], "source": [ "intrinsics = b.Intrinsics(\n", - " height=50,\n", - " width=50,\n", - " fx=250.0, fy=250.0,\n", - " cx=25.0, cy=25.0,\n", - " near=0.01, far=20.0\n", + " height=50, width=50, fx=250.0, fy=250.0, cx=25.0, cy=25.0, near=0.01, far=20.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " 
scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -86,17 +87,26 @@ "source": [ "importance_jit = jax.jit(b.model.importance)\n", "\n", - "contact_enumerators = [b.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"]) for i in range(5)]\n", + "contact_enumerators = [\n", + " b.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"])\n", + " for i in range(5)\n", + "]\n", "add_object_jit = jax.jit(b.add_object)\n", "\n", - "def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):\n", + "\n", + "def c2f_contact_update(\n", + " trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID\n", + "):\n", " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_{number}\"]\n", - " scores = contact_enumerators[number].enumerate_choices_get_scores(trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID)\n", - " i,j,k = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " scores = contact_enumerators[number].enumerate_choices_get_scores(\n", + " trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID\n", + " )\n", + " i, j, k = jnp.unravel_index(scores.argmax(), scores.shape)\n", " return contact_enumerators[number].update_choices(\n", - " trace_, key,\n", - " contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", + " trace_, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", " )\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=(\"number\",))" ] }, @@ -112,17 +122,19 @@ "OUTLIER_GRID = jnp.array([0.00001, 0.0001, 0.001])\n", "\n", "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.05, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.1, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.05, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", - "]\n" + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", + "]" ] }, { @@ -148,7 +160,10 @@ " V_GRID = VARIANCE_GRID\n", " O_GRID = OUTLIER_GRID\n", "else:\n", - " V_GRID, O_GRID = jnp.array([VARIANCE_GRID[V_VARIANT]]), jnp.array([OUTLIER_GRID[O_VARIANT]])\n", + " V_GRID, O_GRID = (\n", + " jnp.array([VARIANCE_GRID[V_VARIANT]]),\n", + " jnp.array([OUTLIER_GRID[O_VARIANT]]),\n", + " )\n", "\n", "print(V_GRID, O_GRID)" ] @@ -198,32 +213,46 @@ " )\n", ")\n", "\n", - "weight, gt_trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"parent_2\": 0,\n", - " \"parent_3\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_parent_2\": 2,\n", - " \"face_parent_3\": 2,\n", - " \"face_child_1\": 3,\n", - " \"face_child_2\": 3,\n", - " \"face_child_3\": 3,\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - "}), (\n", - " jnp.arange(4),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, 
jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, OUTLIER_VOLUME, 1.0)\n", + "weight, gt_trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"parent_2\": 0,\n", + " \"parent_3\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_parent_2\": 2,\n", + " \"face_parent_3\": 2,\n", + " \"face_child_1\": 3,\n", + " \"face_child_2\": 3,\n", + " \"face_child_3\": 3,\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(4),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " OUTLIER_VOLUME,\n", + " 1.0,\n", + " ),\n", ")\n", "print(gt_trace.get_score())\n", "\n", - "_,trace = importance_jit(key, gt_trace.get_choices(), (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:]))\n", + "_, trace = importance_jit(\n", + " key,\n", + " gt_trace.get_choices(),\n", + " (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:]),\n", + ")\n", "b.viz_trace_rendered_observed(trace)" ] }, @@ -286,23 +315,27 @@ "all_all_paths = []\n", "for _ in range(3):\n", " all_paths = []\n", - " for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)):\n", + " for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)):\n", " path = []\n", - " trace_ = add_object_jit(trace, key, obj_id, 0, 2,3)\n", + " trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)\n", " number = b.get_contact_params(trace_).shape[0] - 1\n", " path.append(trace_)\n", " for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace_ = c2f_contact_update_jit(trace_, key, number,\n", - " contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID)\n", + " trace_ = c2f_contact_update_jit(\n", + " trace_,\n", + " key,\n", + " number,\n", + " contact_param_gridding_schedule[c2f_iter],\n", + " V_GRID,\n", + " O_GRID,\n", + " )\n", " path.append(trace_)\n", " # for c2f_iter in range(len(contact_param_gridding_schedule)):\n", " # trace_ = c2f_contact_update_jit(trace_, key, number,\n", " # contact_param_gridding_schedule[c2f_iter], VARIANCE_GRID, OUTLIER_GRID)\n", - " all_paths.append(\n", - " path\n", - " )\n", + " all_paths.append(path)\n", " all_all_paths.append(all_paths)\n", - " \n", + "\n", " scores = jnp.array([t[-1].get_score() for t in all_paths])\n", " print(scores)\n", " normalized_scores = b.utils.normalize_log_scores(scores)\n", diff --git a/scripts/experiments/learning.ipynb b/scripts/experiments/learning.ipynb index c651517c..be82a74f 100644 --- a/scripts/experiments/learning.ipynb +++ b/scripts/experiments/learning.ipynb @@ -35,10 +35,12 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "obj_idx = 4\n", - "mesh_filename = os.path.join(model_dir,\"obj_\" + \"{}\".format(obj_idx+1).rjust(6, '0') + \".ply\")\n", - "SCALING_FACTOR = 1.0/1000.0" + "mesh_filename = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(obj_idx + 1).rjust(6, \"0\") + \".ply\"\n", + ")\n", + "SCALING_FACTOR = 1.0 / 1000.0" ] }, { @@ -49,11 +51,7 @@ "outputs": [], "source": [ "intrinsics = 
b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=500.0, fy=500.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=50.0\n", + " height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=50.0\n", ")\n", "b.setup_renderer(intrinsics)\n", "b.RENDERER.add_mesh_from_file(mesh_filename, scaling_factor=SCALING_FACTOR)" @@ -66,12 +64,20 @@ "metadata": {}, "outputs": [], "source": [ - "object_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, 0.6]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in jnp.linspace(-jnp.pi, jnp.pi, 7)[:-1]])\n", - "observations = b.RENDERER.render_many(object_poses[:,None,...], jnp.array([0]))" + "object_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.6, 0.6]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in jnp.linspace(-jnp.pi, jnp.pi, 7)[:-1]\n", + " ]\n", + ")\n", + "observations = b.RENDERER.render_many(object_poses[:, None, ...], jnp.array([0]))" ] }, { @@ -81,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.hstack_images([b.get_depth_image(o[...,2]) for o in observations])" + "b.hstack_images([b.get_depth_image(o[..., 2]) for o in observations])" ] }, { @@ -92,10 +98,16 @@ "outputs": [], "source": [ "grid = b.utils.make_translation_grid_enumeration_3d(\n", - " -0.1, -0.1, -0.2,\n", - " 0.1, 0.1, 0.2,\n", + " -0.1,\n", + " -0.1,\n", + " -0.2,\n", + " 0.1,\n", + " 0.1,\n", + " 0.2,\n", " # 100, 100, 100\n", - " 60, 60, 60\n", + " 60,\n", + " 60,\n", + " 60,\n", ")\n", "b.show_cloud(\"grid\", grid)" ] @@ -110,16 +122,27 @@ "def voxel_occupied_occluded_free(camera_pose, depth_image, grid, intrinsics, tolerance):\n", " grid_in_cam_frame = b.apply_transform(grid, b.t3d.inverse_pose(camera_pose))\n", " pixels = b.project_cloud_to_pixels(grid_in_cam_frame, intrinsics).astype(jnp.int32)\n", - " valid_pixels = (0 <= pixels[:,0]) * (0 <= pixels[:,1]) * (pixels[:,0] < intrinsics.width) * (pixels[:,1] < intrinsics.height)\n", - " real_depth_vals = depth_image[pixels[:,1],pixels[:,0]] * valid_pixels + (1 - valid_pixels) * (intrinsics.far + 1.0)\n", - " \n", - " projected_depth_vals = grid_in_cam_frame[:,2]\n", + " valid_pixels = (\n", + " (0 <= pixels[:, 0])\n", + " * (0 <= pixels[:, 1])\n", + " * (pixels[:, 0] < intrinsics.width)\n", + " * (pixels[:, 1] < intrinsics.height)\n", + " )\n", + " real_depth_vals = depth_image[pixels[:, 1], pixels[:, 0]] * valid_pixels + (\n", + " 1 - valid_pixels\n", + " ) * (intrinsics.far + 1.0)\n", + "\n", + " projected_depth_vals = grid_in_cam_frame[:, 2]\n", " occupied = jnp.abs(real_depth_vals - projected_depth_vals) < tolerance\n", " occluded = real_depth_vals < projected_depth_vals\n", " occluded = occluded * (1.0 - occupied)\n", " free = (1.0 - occluded) * (1.0 - occupied)\n", " return 1.0 * occupied + 0.5 * occluded\n", - "voxel_occupied_occluded_free_parallel = jax.jit(jax.vmap(voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None)))" + "\n", + "\n", + "voxel_occupied_occluded_free_parallel = jax.jit(\n", + " jax.vmap(voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None))\n", + ")" ] }, { @@ -130,7 +153,7 @@ "outputs": [], "source": [ "occupancies = voxel_occupied_occluded_free_parallel(\n", - " 
b.inverse_pose(object_poses), observations[...,2], grid, intrinsics, 0.001\n", + " b.inverse_pose(object_poses), observations[..., 2], grid, intrinsics, 0.001\n", ")\n", "print(occupancies.sum())" ] @@ -143,7 +166,7 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"grid\", grid[(occupancies > 0.6).sum(0) > 0 ])\n", + "b.show_cloud(\"grid\", grid[(occupancies > 0.6).sum(0) > 0])\n", "# b.show_cloud(\"grid2\", grid[occupancy == 0.5],color=b.RED)" ] }, @@ -154,7 +177,9 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = b.utils.make_voxel_mesh_from_point_cloud(grid[(occupancies > 0.6).sum(0) > 0 ], 0.005 )" + "mesh = b.utils.make_voxel_mesh_from_point_cloud(\n", + " grid[(occupancies > 0.6).sum(0) > 0], 0.005\n", + ")" ] }, { @@ -177,13 +202,10 @@ "source": [ "key = jax.random.PRNGKey(10)\n", "random_pose = b.distributions.gaussian_vmf_jit(\n", - " key,\n", - " b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])),\n", - " 0.01,\n", - " 0.01\n", + " key, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])), 0.01, 0.01\n", ")\n", - "observation = b.RENDERER.render(random_pose[None,...], jnp.array([0]))[...,:3]\n", - "b.get_depth_image(observation[...,2])" + "observation = b.RENDERER.render(random_pose[None, ...], jnp.array([0]))[..., :3]\n", + "b.get_depth_image(observation[..., 2])" ] }, { @@ -193,10 +215,9 @@ "metadata": {}, "outputs": [], "source": [ - "sample_gaussian_vmf_jit = jax.jit(jax.vmap(\n", - " b.distributions.gaussian_vmf_jit,\n", - " in_axes=(0, None, None, None)\n", - "))" + "sample_gaussian_vmf_jit = jax.jit(\n", + " jax.vmap(b.distributions.gaussian_vmf_jit, in_axes=(0, None, None, None))\n", + ")" ] }, { @@ -208,16 +229,23 @@ "source": [ "@genjax.gen\n", "def single_object_model(variance, outlier_prob, outlier_volume):\n", - " pose = b.genjax.uniform_pose(jnp.array([-10.0,-10.0,-10.0]), jnp.array([10.0,10.0,10.0])) @ \"pose\"\n", - " rendered = b.RENDERER.render(\n", - " pose[None,...] 
, jnp.array([0])\n", - " )[...,:3]\n", - " image = b.genjax.image_likelihood(rendered, variance, outlier_prob, outlier_volume) @ \"image\"\n", + " pose = (\n", + " b.genjax.uniform_pose(\n", + " jnp.array([-10.0, -10.0, -10.0]), jnp.array([10.0, 10.0, 10.0])\n", + " )\n", + " @ \"pose\"\n", + " )\n", + " rendered = b.RENDERER.render(pose[None, ...], jnp.array([0]))[..., :3]\n", + " image = (\n", + " b.genjax.image_likelihood(rendered, variance, outlier_prob, outlier_volume)\n", + " @ \"image\"\n", + " )\n", " return rendered\n", "\n", + "\n", "importance_jit = jax.jit(single_object_model.importance)\n", "key = jax.random.PRNGKey(5)\n", - "enumerator = b.genjax.make_enumerator([\"pose\"]) " + "enumerator = b.genjax.make_enumerator([\"pose\"])" ] }, { @@ -228,9 +256,7 @@ "outputs": [], "source": [ "trace = importance_jit(\n", - " key,\n", - " genjax.choice_map({\"image\": observation}),\n", - " (0.001, 0.001, 1000.0)\n", + " key, genjax.choice_map({\"image\": observation}), (0.001, 0.001, 1000.0)\n", ")[1][1]" ] }, @@ -250,9 +276,9 @@ "outputs": [], "source": [ "keys = jax.random.split(key, 1000)\n", - "poses = sample_gaussian_vmf_jit(keys, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])),\n", - " 0.01,\n", - " 0.01)" + "poses = sample_gaussian_vmf_jit(\n", + " keys, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])), 0.01, 0.01\n", + ")" ] }, { @@ -283,8 +309,8 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"obs\", trace[\"image\"].reshape(-1,3))\n", - "b.show_cloud(\"render\", trace.get_retval().reshape(-1,3), color=b.RED)" + "b.show_cloud(\"obs\", trace[\"image\"].reshape(-1, 3))\n", + "b.show_cloud(\"render\", trace.get_retval().reshape(-1, 3), color=b.RED)" ] }, { @@ -295,9 +321,9 @@ "outputs": [], "source": [ "keys = jax.random.split(key, 5000)\n", - "poses = sample_gaussian_vmf_jit(keys, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])),\n", - " 0.01,\n", - " 0.01)" + "poses = sample_gaussian_vmf_jit(\n", + " keys, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])), 0.01, 0.01\n", + ")" ] }, { @@ -307,9 +333,15 @@ "metadata": {}, "outputs": [], "source": [ - "grid_over_pose = jax.jit(jax.vmap(\n", - " lambda trace,key, p: trace.update(key, genjax.choice_map({\"pose\": p}), tuple(map(lambda v: Diff(v, UnknownChange), trace.args)))\n", - ", in_axes=(None, None, 0))\n", + "grid_over_pose = jax.jit(\n", + " jax.vmap(\n", + " lambda trace, key, p: trace.update(\n", + " key,\n", + " genjax.choice_map({\"pose\": p}),\n", + " tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),\n", + " ),\n", + " in_axes=(None, None, 0),\n", + " )\n", ")" ] }, diff --git a/scripts/experiments/likelihood_debug/likelihood_debug.ipynb b/scripts/experiments/likelihood_debug/likelihood_debug.ipynb index a2d4ec7d..c00751c4 100644 --- a/scripts/experiments/likelihood_debug/likelihood_debug.ipynb +++ b/scripts/experiments/likelihood_debug/likelihood_debug.ipynb @@ -49,11 +49,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.0001, far=2.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=2.0\n", ")" ] }, @@ -79,15 +75,19 @@ } ], "source": [ - "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + 
\"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -102,22 +102,30 @@ "num_position_grids = 51\n", "num_angle_grids = 51\n", "contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " num_position_grids,num_position_grids,num_angle_grids\n", + " -width,\n", + " -width,\n", + " -ang,\n", + " width,\n", + " width,\n", + " ang,\n", + " num_position_grids,\n", + " num_position_grids,\n", + " num_angle_grids,\n", ")\n", "\n", "grid_params = [\n", - " (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)), (0.02, jnp.pi, (9,9,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.3, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -130,17 +138,19 @@ "source": [ "table_pose = b.t3d.inverse_pose(\n", " b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.8, .15]),\n", + " jnp.array([0.0, 0.8, 0.15]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 1.0]),\n", " )\n", ")\n", "face_child = 3\n", - "cp_to_pose = lambda cp: table_pose@ b.scene_graph.relative_pose_from_edge(cp, face_child, b.RENDERER.model_box_dims[13])\n", + "cp_to_pose = lambda cp: table_pose @ b.scene_graph.relative_pose_from_edge(\n", + " cp, face_child, b.RENDERER.model_box_dims[13]\n", + ")\n", "cp_to_pose_jit = jax.jit(cp_to_pose)\n", "cp_to_pose_parallel = jax.jit(jax.vmap(cp_to_pose, in_axes=(0,)))\n", "\n", - "key = jax.random.PRNGKey(30)\n" + "key = jax.random.PRNGKey(30)" ] }, { @@ -161,29 +171,32 @@ "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - rendered, axis=-1).mean()\n", "\n", + "\n", "# def score_images(rendered, observed):\n", "# mask = observed[...,2] < intrinsics.far\n", "# return (jnp.linalg.norm(observed - rendered, axis=-1)* (1.0 * mask)).sum() / mask.sum()\n", "\n", + "\n", "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - rendered, axis=-1).mean()\n", "\n", + "\n", "def score_images(rendered, observed):\n", " distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", " probabilities_per_pixel = jax.scipy.stats.norm.logpdf(\n", - " distances,\n", - " 
loc=0.0, \n", - " scale=0.02\n", + " distances, loc=0.0, scale=0.02\n", " )\n", " image_probability = probabilities_per_pixel.mean()\n", " return image_probability\n", "\n", + "\n", "def score_images(rendered, observed):\n", " distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", " width = 0.02\n", - " probabilities_per_pixel = (distances < width/2) / width\n", + " probabilities_per_pixel = (distances < width / 2) / width\n", " return probabilities_per_pixel.mean()\n", "\n", + "\n", "score_vmap = jax.jit(jax.vmap(score_images, in_axes=(0, None)))" ] }, @@ -206,13 +219,13 @@ } ], "source": [ - "key = jax.random.split(key,2)[0]\n", + "key = jax.random.split(key, 2)[0]\n", "key = jnp.array([2755247810, 1586593754], dtype=np.uint32)\n", "low, high = jnp.array([-0.2, -0.2, -jnp.pi]), jnp.array([0.2, 0.2, jnp.pi])\n", - "gt_cp = jax.random.uniform(key, shape=(3,),minval=low, maxval=high)\n", + "gt_cp = jax.random.uniform(key, shape=(3,), minval=low, maxval=high)\n", "gt_pose = cp_to_pose_jit(gt_cp)\n", - "obs_img = b.RENDERER.render(gt_pose[None,...], jnp.array([13]))[...,:3]\n", - "b.viz.scale_image(b.get_depth_image(obs_img[...,2]),3.0)" + "obs_img = b.RENDERER.render(gt_pose[None, ...], jnp.array([13]))[..., :3]\n", + "b.viz.scale_image(b.get_depth_image(obs_img[..., 2]), 3.0)" ] }, { @@ -254,10 +267,18 @@ ], "source": [ "contact_param_grid = gt_cp + contact_param_deltas\n", - "scores = jnp.concatenate([\n", - " score_vmap(b.RENDERER.render_many(cp_to_pose_parallel(cps)[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", - " for cps in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)\n", + "scores = jnp.concatenate(\n", + " [\n", + " score_vmap(\n", + " b.RENDERER.render_many(\n", + " cp_to_pose_parallel(cps)[:, None, ...], jnp.array([13])\n", + " )[..., :3],\n", + " obs_img,\n", + " )\n", + " for cps in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")\n", "\n", "sort_order = jnp.argsort(-scores)\n", "sorted_scores = scores[sort_order]\n", @@ -265,11 +286,13 @@ "print(\"GT CP: \", gt_cp)\n", "print(sorted_scores[:k])\n", "print(contact_param_grid[sort_order[:k]])\n", - "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:,None,...]\n", - "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[...,:3]\n", + "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:, None, ...]\n", + "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[..., :3]\n", "\n", "\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -307,8 +330,8 @@ } ], "source": [ - "observed=obs_img\n", - "rendered=rendered_top_k[1]\n", + "observed = obs_img\n", + "rendered = rendered_top_k[1]\n", "distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", "matches = distances < 0.01\n", "fig = plt.figure()\n", @@ -316,10 +339,10 @@ "print(matches.sum())\n", "ax.imshow(matches)\n", "ax = fig.add_subplot(132)\n", - "ax.imshow(b.viz.preprocess_for_viz(observed[...,2]))\n", + "ax.imshow(b.viz.preprocess_for_viz(observed[..., 2]))\n", "ax = fig.add_subplot(133)\n", - "ax.imshow(b.viz.preprocess_for_viz(observed[...,2]),alpha=0.5)\n", - "ax.imshow(matches,alpha=0.5)" + "ax.imshow(b.viz.preprocess_for_viz(observed[..., 2]), alpha=0.5)\n", + "ax.imshow(matches, alpha=0.5)" ] }, { @@ -350,10 +373,10 @@ } ], "source": [ - "diffs = rendered_top_k[:,:,:,2] - 
obs_img[:,:,2][None,...]\n", + "diffs = rendered_top_k[:, :, :, 2] - obs_img[:, :, 2][None, ...]\n", "i = 1\n", "\n", - "plt.imshow(jnp.abs(diffs[i]) > 0.01,alpha=0.5)\n", + "plt.imshow(jnp.abs(diffs[i]) > 0.01, alpha=0.5)\n", "plt.colorbar()" ] }, @@ -382,23 +405,28 @@ "for cp_grid in contact_param_gridding_schedule:\n", " contact_param_grid = cp + cp_grid\n", " cp_poses = cp_to_pose_parallel(contact_param_grid)\n", - " scores = score_vmap(b.RENDERER.render_many(cp_poses[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", + " scores = score_vmap(\n", + " b.RENDERER.render_many(cp_poses[:, None, ...], jnp.array([13]))[..., :3],\n", + " obs_img,\n", + " )\n", " best_idx = jnp.argmax(scores)\n", " cp = contact_param_grid[best_idx]\n", " cps.append(cp)\n", " poses.append(cp_poses[best_idx])\n", - "images_over_time = b.RENDERER.render_many(cp_to_pose_parallel(jnp.stack(cps))[:,None,...], jnp.array([13]))[...,:3]\n", + "images_over_time = b.RENDERER.render_many(\n", + " cp_to_pose_parallel(jnp.stack(cps))[:, None, ...], jnp.array([13])\n", + ")[..., :3]\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", "colors = b.distinct_colors(len(images_over_time))\n", "b.show_trimesh(f\"gt\", b.RENDERER.meshes[13], opacity=0.5, color=b.RED)\n", "b.set_pose(f\"gt\", gt_pose)\n", "for i in range(len(images_over_time)):\n", " b.show_trimesh(f\"_{i}\", b.RENDERER.meshes[13], opacity=0.5, color=colors[i])\n", " b.set_pose(f\"_{i}\", poses[i])\n", - " \n", - "b.hstack_images([b.get_depth_image(i[...,2]) for i in images_over_time])" + "\n", + "b.hstack_images([b.get_depth_image(i[..., 2]) for i in images_over_time])" ] }, { @@ -421,10 +449,18 @@ ], "source": [ "contact_param_grid = cp + contact_param_deltas\n", - "scores = jnp.concatenate([\n", - " score_vmap(b.RENDERER.render_many(cp_to_pose_parallel(cps)[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", - " for cps in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)\n", + "scores = jnp.concatenate(\n", + " [\n", + " score_vmap(\n", + " b.RENDERER.render_many(\n", + " cp_to_pose_parallel(cps)[:, None, ...], jnp.array([13])\n", + " )[..., :3],\n", + " obs_img,\n", + " )\n", + " for cps in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")\n", "\n", "sort_order = jnp.argsort(-scores)\n", "sorted_scores = scores[sort_order]\n", @@ -432,11 +468,13 @@ "# print(\"GT CP: \", gt_cp)\n", "# print(sorted_scores[:k])\n", "# print(contact_param_grid[sort_order[:k]])\n", - "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:,None,...]\n", - "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[...,:3]\n", + "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:, None, ...]\n", + "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[..., :3]\n", "\n", "\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -497,9 +535,11 @@ ], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", - "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1,3), color=b.RED)\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", + "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1, 3), color=b.RED)\n", + "b.viz.scale_image(\n", + " 
b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -522,9 +562,11 @@ ], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", - "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1,3), color=b.RED)\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", + "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1, 3), color=b.RED)\n", + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -535,10 +577,10 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", "colors = b.distinct_colors(len(images_over_time))\n", "for i in range(len(images_over_time)):\n", - " b.show_cloud(f\"_{i}\", images_over_time[i].reshape(-1,3), color=colors[i])" + " b.show_cloud(f\"_{i}\", images_over_time[i].reshape(-1, 3), color=colors[i])" ] }, { @@ -723,10 +765,13 @@ ], "source": [ "contact_param_grid = trace[\"contact_params_1\"] + contact_param_deltas\n", - "images = jnp.concatenate([\n", - " enumerators[2](trace, key, cp)[\"image\"]\n", - " for cp in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)" + "images = jnp.concatenate(\n", + " [\n", + " enumerators[2](trace, key, cp)[\"image\"]\n", + " for cp in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")" ] }, { @@ -909,27 +954,45 @@ " # fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", " widths = [1, 1]\n", " heights = [2]\n", - " spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", - " \n", + " spec = fig.add_gridspec(\n", + " ncols=2, nrows=1, width_ratios=widths, height_ratios=heights\n", + " )\n", + "\n", " ax = fig.add_subplot(spec[0, 0])\n", - " ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + " ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + " ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + " )\n", " # ax.set_title(f\"Observed Depth\")\n", - " \n", - " \n", + "\n", " ax = fig.add_subplot(spec[0, 1])\n", " ax.set_aspect(1.0)\n", - " circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle=\"--\", linewidth=0.5)\n", + " circ = plt.Circle(\n", + " (0, 0),\n", + " radius=1,\n", + " edgecolor=\"black\",\n", + " facecolor=\"None\",\n", + " linestyle=\"--\",\n", + " linewidth=0.5,\n", + " )\n", " ax.add_patch(circ)\n", " ax.set_xlim(-1.1, 1.1)\n", " ax.set_ylim(-1.1, 1.1)\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.scatter(-jnp.sin(sampled_contacts[:,2]),jnp.cos(sampled_contacts[:,2]), color='red',label=\"Posterior Samples\", alpha=0.5, s=30)\n", - " ax.scatter(-jnp.sin(gt_contact[2]),jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25)\n", + " ax.scatter(\n", + " -jnp.sin(sampled_contacts[:, 2]),\n", + " jnp.cos(sampled_contacts[:, 2]),\n", + " color=\"red\",\n", + " label=\"Posterior Samples\",\n", + " alpha=0.5,\n", + " s=30,\n", + " )\n", + " ax.scatter(\n", + " -jnp.sin(gt_contact[2]), jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25\n", + " )\n", 
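+ "    )\n",
+ "    # Each sampled contact angle is plotted at (-sin(angle), cos(angle)) on the\n",
+ "    # unit circle, so the spread of the red points shows orientation uncertainty.\n",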
" ax.set_title(\"Posterior on Orientation (top view)\")\n", " # ax.legend(fontsize=9)\n", " # plt.show()\n", @@ -947,10 +1010,9 @@ " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_1\"]\n", " scores = enumerators[3](trace_, key, contact_param_grid)\n", " i = scores.argmax()\n", - " return enumerators[0](\n", - " trace_, key,\n", - " contact_param_grid[i]\n", - " )\n", + " return enumerators[0](trace_, key, contact_param_grid[i])\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update)" ] }, @@ -973,16 +1035,18 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)), (0.02, jnp.pi, (9,9,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.3, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -993,7 +1057,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]" + "key = jax.random.split(key, 2)[0]" ] }, { @@ -1023,24 +1087,36 @@ ], "source": [ "low, high = jnp.array([-0.2, -0.2, -jnp.pi]), jnp.array([0.2, 0.2, jnp.pi])\n", - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 3,\n", - " \"face_child_1\": 2,\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jax.random.uniform(key, shape=(3,),minval=low, maxval=high)\n", - "}), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.5, -0.5, -2*jnp.pi]), jnp.array([0.5, 0.5, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, intrinsics.fx)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 3,\n", + " \"face_child_1\": 2,\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jax.random.uniform(\n", + " key, shape=(3,), minval=low, maxval=high\n", + " ),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.5, -0.5, -2 * jnp.pi]), jnp.array([0.5, 0.5, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " intrinsics.fx,\n", + " ),\n", ")\n", "gt_poses = b.get_poses(trace)\n", "gt_contact = trace[\"contact_params_1\"]\n", @@ -1081,7 +1157,9 @@ "path = []\n", "path.append(trace)\n", "for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace = 
c2f_contact_update_jit(trace, key, contact_param_gridding_schedule[c2f_iter])\n", + " trace = c2f_contact_update_jit(\n", + " trace, key, contact_param_gridding_schedule[c2f_iter]\n", + " )\n", " path.append(trace)\n", "print(trace[\"contact_params_1\"])\n", "b.viz_trace_rendered_observed(trace)" @@ -1105,10 +1183,10 @@ "source": [ "%%time\n", "contact_param_grid = trace[\"contact_params_1\"] + contact_param_deltas\n", - "weights = jnp.concatenate([\n", - " enumerators[3](trace, key, cp)\n", - " for cp in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)" + "weights = jnp.concatenate(\n", + " [enumerators[3](trace, key, cp) for cp in jnp.array_split(contact_param_grid, 15)],\n", + " axis=0,\n", + ")" ] }, { @@ -1120,7 +1198,9 @@ "source": [ "key2 = jax.random.split(key, 1)[0]\n", "normalized_weights = b.utils.normalize_log_scores(weights)\n", - "sampled_indices = jax.random.choice(key2,jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights)\n", + "sampled_indices = jax.random.choice(\n", + " key2, jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights\n", + ")\n", "sampled_contact_params = contact_param_grid[sampled_indices]" ] }, @@ -1170,21 +1250,21 @@ ], "source": [ "fig = plt.figure()\n", - "ax = fig.add_subplot(projection='3d')\n", + "ax = fig.add_subplot(projection=\"3d\")\n", "ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "# make the grid lines transparent\n", - "ax.xaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.yaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.zaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "u, v = np.mgrid[0:2*np.pi:21j, 0:np.pi:11j]\n", - "x = np.cos(u)*np.sin(v)\n", - "y = np.sin(u)*np.sin(v)\n", + "ax.xaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.yaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.zaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "u, v = np.mgrid[0 : 2 * np.pi : 21j, 0 : np.pi : 11j]\n", + "x = np.cos(u) * np.sin(v)\n", + "y = np.sin(u) * np.sin(v)\n", "z = np.cos(v)\n", - "ax.axes.set_xlim3d(-1.1, 1.1) \n", - "ax.axes.set_ylim3d(-1.1, 1.1) \n", - "ax.axes.set_zlim3d(-1.1, 1.1) \n", + "ax.axes.set_xlim3d(-1.1, 1.1)\n", + "ax.axes.set_ylim3d(-1.1, 1.1)\n", + "ax.axes.set_zlim3d(-1.1, 1.1)\n", "ax.set_aspect(\"equal\")\n", "ax.plot_wireframe(x, y, z, color=(0.0, 0.0, 0.0, 0.3), linewidths=0.5)\n", "\n", @@ -1195,15 +1275,25 @@ "\n", "points = []\n", "NUM = 1\n", - "offset = jnp.pi/2\n", + "offset = jnp.pi / 2\n", "scaling = 0.96\n", "for i in sampled_contact_params:\n", - " points.append(np.array([np.cos(i[2] + offset) * scaling, np.sin(i[2] + offset) * scaling,0.0]))\n", + " points.append(\n", + " np.array(\n", + " [np.cos(i[2] + offset) * scaling, np.sin(i[2] + offset) * scaling, 0.0]\n", + " )\n", + " )\n", "points = np.array(points)\n", "\n", "z = 0.1\n", - "for i in np.arange(.1,1.01,.1):\n", - " ax.scatter(points[:,0], points[:,1],points[:,2], s=(40*i*(z*.9+.1))**2, color=(1,0,0,.3/i/10))\n", + "for i in np.arange(0.1, 1.01, 0.1):\n", + " ax.scatter(\n", + " points[:, 0],\n", + " points[:, 1],\n", + " points[:, 2],\n", + " s=(40 * i * (z * 0.9 + 0.1)) ** 2,\n", + " color=(1, 0, 0, 0.3 / i / 10),\n", + " )\n", "# offset = jnp.pi/2\n", "# angle = jnp.pi/4 - jnp.pi/4 - jnp.pi/4 - jnp.pi/4\n", "# for i in np.arange(.1,1.01,.1):\n", @@ -1223,13 +1313,18 @@ "scaled_up_intrinsics = b.scale_camera_parameters(intrinsics, 4)\n", "\n", 
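+ "# scale_camera_parameters is assumed to scale height/width and fx/fy/cx/cy by the\n",
+ "# given factor, so this re-renders the same scene at 4x resolution for the figures.\n",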
"b.setup_renderer(scaled_up_intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -1249,14 +1344,14 @@ "metadata": {}, "outputs": [], "source": [ - "depth = img[...,2]\n", + "depth = img[..., 2]\n", "minval = jnp.min(depth[depth > jnp.min(depth)])\n", "maxval = jnp.max(depth[depth < jnp.max(depth)])\n", "depth = depth.at[depth >= intrinsics.far].set(jnp.nan)\n", - "viz_img = np.array(b.viz.scale_image(b.get_depth_image(\n", - " depth, min=minval, max=maxval\n", - "), 3))\n", - "viz_img[viz_img.sum(-1) == 0,:] = 255.0\n", + "viz_img = np.array(\n", + " b.viz.scale_image(b.get_depth_image(depth, min=minval, max=maxval), 3)\n", + ")\n", + "viz_img[viz_img.sum(-1) == 0, :] = 255.0\n", "plt.imshow(viz_img)\n", "plt.xticks([])\n", "plt.yticks([])\n", @@ -1299,7 +1394,7 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.linalg.norm(points,axis=-1)" + "jnp.linalg.norm(points, axis=-1)" ] }, { @@ -1319,48 +1414,57 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "fig = plt.figure(constrained_layout=True)\n", "\n", - "observation = trace[\"image\"]\n", + "observation = trace[\"image\"]\n", "\n", "# fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", "widths = [1, 1]\n", "heights = [2]\n", - "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", + "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths, height_ratios=heights)\n", "\n", "ax = fig.add_subplot(spec[0, 0])\n", - "ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + "ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", "ax.get_xaxis().set_visible(False)\n", "ax.get_yaxis().set_visible(False)\n", - "ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + "ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + ")\n", "# ax.set_title(f\"Observed Depth\")\n", "\n", "\n", "dist = 0.6\n", "ax = fig.add_subplot(spec[0, 1])\n", "ax.quiver(\n", - " sampled_contact_params[:,0],sampled_contact_params[:,1],\n", - " -jnp.sin(sampled_contact_params[:,2]),jnp.cos(sampled_contact_params[:,2]),\n", + " sampled_contact_params[:, 0],\n", + " sampled_contact_params[:, 1],\n", + " -jnp.sin(sampled_contact_params[:, 2]),\n", + " jnp.cos(sampled_contact_params[:, 2]),\n", " scale=3.0,\n", - " alpha=0.1\n", - " )\n", + " alpha=0.1,\n", + ")\n", "\n", "ax.quiver(\n", - " gt_contact[0],gt_contact[1],\n", - " -jnp.sin(gt_contact[2]), 
jnp.cos(gt_contact[2]),\n",
"    scale=5.0,\n",
"    alpha=0.8,\n",
- "    color=\"red\"\n",
+ "    color=\"red\",\n",
")\n",
"\n",
"ax.set_aspect(1.0)\n",
"from matplotlib.patches import Rectangle\n",
- "ax.add_patch(Rectangle((gt_contact[0]-width, gt_contact[1]-width), 2*width, 2*width,fill=None))\n", "\n",
- "ax.set_xlim(gt_contact[0]-width-0.02, gt_contact[0]+width+0.02)\n",
- "ax.set_ylim(gt_contact[1]-width-0.02, gt_contact[1]+width+0.02)"
+ "ax.add_patch(\n",
+ "    Rectangle(\n",
+ "        (gt_contact[0] - width, gt_contact[1] - width), 2 * width, 2 * width, fill=None\n",
+ "    )\n",
+ ")\n",
+ "\n",
+ "ax.set_xlim(gt_contact[0] - width - 0.02, gt_contact[0] + width + 0.02)\n",
+ "ax.set_ylim(gt_contact[1] - width - 0.02, gt_contact[1] + width + 0.02)"
]
},
{
@@ -1375,8 +1479,8 @@
"best_cell_idx = jnp.abs(contact_param_grid - gt_contact).sum(1).argmin()\n",
"print(gt_contact, contact_param_grid[best_cell_idx])\n",
"normalize_log_weights = weights - b.logsumexp(weights)\n",
- "assert(weights.shape[0] == contact_param_grid.shape[0])\n",
- "volume = (width / num_position_grids)**2 * (2*jnp.pi / num_angle_grids)\n",
+ "assert weights.shape[0] == contact_param_grid.shape[0]\n",
+ "volume = (width / num_position_grids) ** 2 * (2 * jnp.pi / num_angle_grids)\n",
"log_likelihood = normalize_log_weights[best_cell_idx] - jnp.log(volume)\n",
"print(log_likelihood)"
]
diff --git a/scripts/experiments/likelihood_debug/real_scene_parse.ipynb b/scripts/experiments/likelihood_debug/real_scene_parse.ipynb
index 2a81e279..385da94e 100644
--- a/scripts/experiments/likelihood_debug/real_scene_parse.ipynb
+++ b/scripts/experiments/likelihood_debug/real_scene_parse.ipynb
@@ -21,9 +21,11 @@
"import glob\n",
"import bayes3d.neural\n",
"import pickle\n",
+ "\n",
"# Can be helpful for debugging:\n",
- "# jax.config.update('jax_enable_checks', True) \n",
+ "# jax.config.update('jax_enable_checks', True)\n",
"from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n",
+ "\n",
"# import genjax\n",
"from tqdm import tqdm"
]
},
@@ -50,12 +52,11 @@
{
"cell_type": "code",
"execution_count": 3,
+ "id": "7fb27b941602401d91542211134fc71a",
"metadata": {},
"outputs": [],
"source": [
- "paths = glob.glob(\n",
- "    \"panda_scans_v6/*.pkl\"\n",
- ")\n",
+ "paths = glob.glob(\"panda_scans_v6/*.pkl\")\n",
"all_data = pickle.load(open(paths[0], \"rb\"))\n",
"IDX = 0\n",
"data = all_data[IDX]"
]
},
{
"cell_type": "code",
"execution_count": 4,
+ "id": "acae54e37e7d407bbb7b55eff062a284",
"metadata": {},
"outputs": [
{
...
}
],
"source": [
"print(data[\"camera_image\"].keys())\n",
- "K = data[\"camera_image\"]['camera_matrix'][0]\n",
- "rgb = data[\"camera_image\"]['rgbPixels']\n",
- "depth = data[\"camera_image\"]['depthPixels']\n",
- "camera_pose = data[\"camera_image\"]['camera_pose']\n",
+ "K = data[\"camera_image\"][\"camera_matrix\"][0]\n",
+ "rgb = data[\"camera_image\"][\"rgbPixels\"]\n",
+ "depth = data[\"camera_image\"][\"depthPixels\"]\n",
+ "camera_pose = data[\"camera_image\"][\"camera_pose\"]\n",
"camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n",
- "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n",
- "h,w = depth.shape\n",
+ "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n",
+ "h, w = depth.shape\n",
"near = 0.001\n",
- "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n",
+ "rgbd_original = b.RGBD(\n",
+ "    rgb, depth, 
camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 10000.0)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, { "cell_type": "code", "execution_count": 5, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": {}, "outputs": [ { @@ -123,6 +128,7 @@ { "cell_type": "code", "execution_count": 6, + "id": "8dd0d8092fe74a7c96281538738b07e2", "metadata": {}, "outputs": [], "source": [ @@ -133,13 +139,17 @@ { "cell_type": "code", "execution_count": 7, + "id": "72eea5119410473aa328ad9291626812", "metadata": {}, "outputs": [], "source": [ "table_pose, plane_dims = b.utils.infer_table_plane(\n", - " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics),jnp.eye(4),\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics),\n", + " jnp.eye(4),\n", " rgbd_scaled_down.intrinsics,\n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=0.1,\n", ")\n", "face_child = 3" ] @@ -147,17 +157,24 @@ { "cell_type": "code", "execution_count": 9, + "id": "8edb47106e1a46a883d545849b8ab81b", "metadata": {}, "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", table_pose)" ] }, { "cell_type": "code", "execution_count": 10, + "id": "10185d26023b46108eb7d9f57d49d2b3", "metadata": {}, "outputs": [ { @@ -179,12 +196,15 @@ "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(13+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(10+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(13 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(10 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -199,25 +219,30 @@ "num_position_grids = 51\n", "num_angle_grids = 51\n", "contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " num_position_grids,num_position_grids,num_angle_grids\n", + " -width,\n", + " -width,\n", + " -ang,\n", + " width,\n", + " width,\n", + " ang,\n", + " num_position_grids,\n", + " num_position_grids,\n", + " num_angle_grids,\n", ")\n", "\n", "grid_params = [\n", - " (0.5, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)),\n", - " 
(0.02, jnp.pi, (9,9,51))\n", - " , (0.01, jnp.pi/5, (15,15,15)),\n", - " (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.5, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -228,47 +253,68 @@ "metadata": {}, "outputs": [], "source": [ - "_cp_to_pose = lambda cp, index: table_pose @ b.scene_graph.relative_pose_from_edge(cp, face_child, b.RENDERER.model_box_dims[index])\n", - "cps_to_pose= jax.vmap(_cp_to_pose, in_axes=(0,0,))\n", + "_cp_to_pose = lambda cp, index: table_pose @ b.scene_graph.relative_pose_from_edge(\n", + " cp, face_child, b.RENDERER.model_box_dims[index]\n", + ")\n", + "cps_to_pose = jax.vmap(\n", + " _cp_to_pose,\n", + " in_axes=(\n", + " 0,\n", + " 0,\n", + " ),\n", + ")\n", "cps_to_pose_jit = jax.jit(cps_to_pose)\n", - "cps_to_pose_parallel = jax.vmap(cps_to_pose, in_axes=(0,None,))\n", + "cps_to_pose_parallel = jax.vmap(\n", + " cps_to_pose,\n", + " in_axes=(\n", + " 0,\n", + " None,\n", + " ),\n", + ")\n", "cps_to_pose_parallel_jit = jax.jit(cps_to_pose_parallel)\n", "\n", "key = jax.random.PRNGKey(30)\n", "\n", + "\n", "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - rendered, axis=-1).mean()\n", "\n", + "\n", "# def score_images(rendered, observed):\n", "# mask = observed[...,2] < intrinsics.far\n", "# return (jnp.linalg.norm(observed - rendered, axis=-1)* (1.0 * mask)).sum() / mask.sum()\n", "\n", + "\n", "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - rendered, axis=-1).mean()\n", "\n", + "\n", "# def score_images(rendered, observed):\n", "# distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", "# probabilities_per_pixel = jax.scipy.stats.norm.pdf(\n", "# distances,\n", - "# loc=0.0, \n", + "# loc=0.0,\n", "# scale=0.02\n", "# )\n", "# image_probability = probabilities_per_pixel.mean()\n", "# return image_probability\n", "\n", + "\n", "def score_images(rendered, observed):\n", " distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", " width = 0.02\n", - " probabilities_per_pixel = (distances < width/2) / width\n", + " probabilities_per_pixel = (distances < width / 2) / width\n", " return probabilities_per_pixel.mean()\n", "\n", + "\n", "score_vmap = jax.jit(jax.vmap(score_images, in_axes=(0, None)))\n", "\n", + "\n", "def grid_and_max(cps, indices, number, grid, obs_img):\n", - " cps_expanded = jnp.repeat(cps[None,...], grid.shape[0], axis=0)\n", - " cps_expanded = cps_expanded.at[:,number,:].set(cps_expanded[:,number,:] + grid)\n", + " cps_expanded = jnp.repeat(cps[None, ...], grid.shape[0], axis=0)\n", + " cps_expanded = cps_expanded.at[:, number, :].set(cps_expanded[:, number, :] + grid)\n", " cp_poses = cps_to_pose_parallel(cps_expanded, indices)\n", - " rendered_images = b.RENDERER.render_many(cp_poses, indices)[...,:3]\n", + " rendered_images = b.RENDERER.render_many(cp_poses, indices)[..., :3]\n", " scores = score_vmap(rendered_images, obs_img)\n", " best_idx = jnp.argmax(scores)\n", " cps = 
cps_expanded[best_idx]\n", @@ -282,16 +328,23 @@ "metadata": {}, "outputs": [], "source": [ - "def c2f(potential_cps, potential_indices, number, contact_param_gridding_schedule, obs_img):\n", + "def c2f(\n", + " potential_cps, potential_indices, number, contact_param_gridding_schedule, obs_img\n", + "):\n", " for cp_grid in contact_param_gridding_schedule:\n", - " potential_cps, score = grid_and_max(potential_cps, potential_indices, number, cp_grid, obs_img)\n", + " potential_cps, score = grid_and_max(\n", + " potential_cps, potential_indices, number, cp_grid, obs_img\n", + " )\n", " return potential_cps, score\n", + "\n", + "\n", "c2f_jit = jax.jit(c2f)" ] }, { "cell_type": "code", "execution_count": 14, + "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": {}, "outputs": [], "source": [ @@ -301,16 +354,18 @@ { "cell_type": "code", "execution_count": 25, + "id": "7623eae2785240b9bd12b16a66d81610", "metadata": {}, "outputs": [], "source": [ - "cps = jnp.zeros((0,3))\n", + "cps = jnp.zeros((0, 3))\n", "indices = jnp.array([], dtype=jnp.int32)" ] }, { "cell_type": "code", "execution_count": 27, + "id": "7cdc8c89c7104fffa095e18ddfef8986", "metadata": {}, "outputs": [ { @@ -325,19 +380,36 @@ } ], "source": [ - "\n", - "\n", "best_score = 0.0\n", "best_index = -1\n", "best_cps = None\n", "best_indices = None\n", - "key = jax.random.split(key,2)[0]\n", + "key = jax.random.split(key, 2)[0]\n", "low, high = jnp.array([-0.4, -0.4, -jnp.pi]), jnp.array([0.4, 0.4, jnp.pi])\n", "\n", "for next_index in range(len(b.RENDERER.model_box_dims)):\n", " potential_indices = jnp.concatenate([indices, jnp.array([next_index])])\n", - " potential_cps = jnp.concatenate([cps, jax.random.uniform(key, shape=(1,3,),minval=low, maxval=high)])\n", - " potential_cps, score = c2f_jit(potential_cps, potential_indices, len(potential_indices) - 1, contact_param_gridding_schedule, obs_img)\n", + " potential_cps = jnp.concatenate(\n", + " [\n", + " cps,\n", + " jax.random.uniform(\n", + " key,\n", + " shape=(\n", + " 1,\n", + " 3,\n", + " ),\n", + " minval=low,\n", + " maxval=high,\n", + " ),\n", + " ]\n", + " )\n", + " potential_cps, score = c2f_jit(\n", + " potential_cps,\n", + " potential_indices,\n", + " len(potential_indices) - 1,\n", + " contact_param_gridding_schedule,\n", + " obs_img,\n", + " )\n", " print(score)\n", " if score > best_score:\n", " best_index = next_index\n", @@ -348,7 +420,7 @@ "indices = best_indices\n", "\n", "b.clear()\n", - "b.show_cloud(\"obs\", obs_img.reshape(-1,3))\n", + "b.show_cloud(\"obs\", obs_img.reshape(-1, 3))\n", "b.show_pose(\"table\", table_pose)\n", "poses = cps_to_pose(cps, indices)\n", "for i in range(len(poses)):\n", @@ -359,6 +431,7 @@ { "cell_type": "code", "execution_count": null, + "id": "b118ea5561624da68c537baed56e602f", "metadata": {}, "outputs": [], "source": [] diff --git a/scripts/experiments/likelihood_debug/scene_parse.ipynb b/scripts/experiments/likelihood_debug/scene_parse.ipynb index f78a9d5b..ec009751 100644 --- a/scripts/experiments/likelihood_debug/scene_parse.ipynb +++ b/scripts/experiments/likelihood_debug/scene_parse.ipynb @@ -55,11 +55,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=200.0, fy=200.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.0001, far=2.0\n", + " height=100, width=100, fx=200.0, fy=200.0, cx=50.0, cy=50.0, near=0.0001, far=2.0\n", ")" ] }, @@ -85,15 +81,19 @@ } ], "source": [ - "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = 
os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -108,25 +108,29 @@ "num_position_grids = 51\n", "num_angle_grids = 51\n", "contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " num_position_grids,num_position_grids,num_angle_grids\n", + " -width,\n", + " -width,\n", + " -ang,\n", + " width,\n", + " width,\n", + " ang,\n", + " num_position_grids,\n", + " num_position_grids,\n", + " num_angle_grids,\n", ")\n", "\n", "grid_params = [\n", - " (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)),\n", - " (0.02, jnp.pi, (9,9,51))\n", - " , (0.01, jnp.pi/5, (15,15,15)),\n", + " (0.3, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", " # (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -139,7 +143,7 @@ "source": [ "table_pose = b.t3d.inverse_pose(\n", " b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, .35]),\n", + " jnp.array([0.0, 0.6, 0.35]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 1.0]),\n", " )\n", @@ -154,36 +158,56 @@ "metadata": {}, "outputs": [], "source": [ - "_cp_to_pose = lambda cp, index: table_pose@ b.scene_graph.relative_pose_from_edge(cp, face_child, b.RENDERER.model_box_dims[index])\n", - "cps_to_pose= jax.vmap(_cp_to_pose, in_axes=(0,0,))\n", + "_cp_to_pose = lambda cp, index: table_pose @ b.scene_graph.relative_pose_from_edge(\n", + " cp, face_child, b.RENDERER.model_box_dims[index]\n", + ")\n", + "cps_to_pose = jax.vmap(\n", + " _cp_to_pose,\n", + " in_axes=(\n", + " 0,\n", + " 0,\n", + " ),\n", + ")\n", "cps_to_pose_jit = jax.jit(cps_to_pose)\n", - "cps_to_pose_parallel = jax.vmap(cps_to_pose, in_axes=(0,None,))\n", + "cps_to_pose_parallel = jax.vmap(\n", + " cps_to_pose,\n", + " in_axes=(\n", + " 0,\n", + " None,\n", + " ),\n", + ")\n", "cps_to_pose_parallel_jit = jax.jit(cps_to_pose_parallel)\n", "\n", "key = jax.random.PRNGKey(30)\n", "\n", + "\n", "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - 
rendered, axis=-1).mean()\n", "\n", + "\n", "# def score_images(rendered, observed):\n", "# mask = observed[...,2] < intrinsics.far\n", "# return (jnp.linalg.norm(observed - rendered, axis=-1)* (1.0 * mask)).sum() / mask.sum()\n", "\n", + "\n", "def score_images(rendered, observed):\n", " return -jnp.linalg.norm(observed - rendered, axis=-1).mean()\n", "\n", + "\n", "# def score_images(rendered, observed):\n", "# distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", "# probabilities_per_pixel = jax.scipy.stats.norm.pdf(\n", "# distances,\n", - "# loc=0.0, \n", + "# loc=0.0,\n", "# scale=0.02\n", "# )\n", "# image_probability = probabilities_per_pixel.mean()\n", "# return image_probability\n", "\n", + "\n", "def score_images(rendered, observed):\n", - " return b.threedp3_likelihood(observed, rendered, 0.02, 0.0)\n", + " return b.threedp3_likelihood(observed, rendered, 0.02, 0.0)\n", + "\n", "\n", "score_vmap = jax.jit(jax.vmap(score_images, in_axes=(0, None)))" ] @@ -206,7 +230,7 @@ } ], "source": [ - "cps_to_pose_parallel_jit(jnp.zeros((1000,1,3)), jnp.array([0])).shape" + "cps_to_pose_parallel_jit(jnp.zeros((1000, 1, 3)), jnp.array([0])).shape" ] }, { @@ -217,10 +241,10 @@ "outputs": [], "source": [ "def grid_and_max(cps, indices, number, grid, obs_img):\n", - " cps_expanded = jnp.repeat(cps[None,...], grid.shape[0], axis=0)\n", - " cps_expanded = cps_expanded.at[:,number,:].set(cps_expanded[:,number,:] + grid)\n", + " cps_expanded = jnp.repeat(cps[None, ...], grid.shape[0], axis=0)\n", + " cps_expanded = cps_expanded.at[:, number, :].set(cps_expanded[:, number, :] + grid)\n", " cp_poses = cps_to_pose_parallel(cps_expanded, indices)\n", - " scores = score_vmap(b.RENDERER.render_many(cp_poses, indices)[...,:3], obs_img)\n", + " scores = score_vmap(b.RENDERER.render_many(cp_poses, indices)[..., :3], obs_img)\n", " best_idx = jnp.argmax(scores)\n", " cps = cps_expanded[best_idx]\n", " return cps, scores[best_idx]" @@ -233,12 +257,18 @@ "metadata": {}, "outputs": [], "source": [ - "def c2f(potential_cps, potential_indices, number, contact_param_gridding_schedule, obs_img):\n", + "def c2f(\n", + " potential_cps, potential_indices, number, contact_param_gridding_schedule, obs_img\n", + "):\n", " # cps_over_inference = [potential_cps]\n", " for cp_grid in contact_param_gridding_schedule:\n", - " potential_cps, score = grid_and_max(potential_cps, potential_indices, number, cp_grid, obs_img)\n", + " potential_cps, score = grid_and_max(\n", + " potential_cps, potential_indices, number, cp_grid, obs_img\n", + " )\n", " # cps_over_inference.append(potential_cps)\n", " return potential_cps, score\n", + "\n", + "\n", "c2f_jit = jax.jit(c2f)" ] }, @@ -326,20 +356,22 @@ ], "source": [ "for _ in tqdm(range(50)):\n", - " key = jax.random.split(key,2)[0]\n", + " key = jax.random.split(key, 2)[0]\n", " # gt_indices = \"jnp.array([13,11])\n", " gt_indices = jax.random.choice(key, 21, shape=(3,))\n", " # key = jnp.array([2755247810, 1586593754], dtype=np.uint32)\n", " low, high = jnp.array([-0.25, -0.25, -jnp.pi]), jnp.array([0.25, 0.25, jnp.pi])\n", - " gt_cps = jax.random.uniform(key, shape=(gt_indices.shape[0],3),minval=low, maxval=high)\n", + " gt_cps = jax.random.uniform(\n", + " key, shape=(gt_indices.shape[0], 3), minval=low, maxval=high\n", + " )\n", " gt_poses = cps_to_pose(gt_cps, gt_indices)\n", - " obs_img = b.RENDERER.render(gt_poses, gt_indices)[...,:3]\n", + " obs_img = b.RENDERER.render(gt_poses, gt_indices)[..., :3]\n", "\n", - " b.get_depth_image(obs_img[...,2])\n", + " 
b.get_depth_image(obs_img[..., 2])\n", "\n", " inference_results = []\n", " for _ in range(2):\n", - " cps = jnp.zeros((0,3))\n", + " cps = jnp.zeros((0, 3))\n", " indices = jnp.array([], dtype=jnp.int32)\n", "\n", " for _ in range(3):\n", @@ -347,11 +379,30 @@ " best_index = -1\n", " best_cps = None\n", " best_indices = None\n", - " key = jax.random.split(key,2)[0]\n", + " key = jax.random.split(key, 2)[0]\n", " for next_index in range(21):\n", " potential_indices = jnp.concatenate([indices, jnp.array([next_index])])\n", - " potential_cps = jnp.concatenate([cps, jax.random.uniform(key, shape=(1,3,),minval=low, maxval=high)])\n", - " potential_cps, score = c2f_jit(potential_cps, potential_indices, len(potential_indices) - 1, contact_param_gridding_schedule, obs_img)\n", + " potential_cps = jnp.concatenate(\n", + " [\n", + " cps,\n", + " jax.random.uniform(\n", + " key,\n", + " shape=(\n", + " 1,\n", + " 3,\n", + " ),\n", + " minval=low,\n", + " maxval=high,\n", + " ),\n", + " ]\n", + " )\n", + " potential_cps, score = c2f_jit(\n", + " potential_cps,\n", + " potential_indices,\n", + " len(potential_indices) - 1,\n", + " contact_param_gridding_schedule,\n", + " obs_img,\n", + " )\n", " if score > best_score:\n", " best_index = next_index\n", " best_score = score\n", @@ -361,13 +412,20 @@ " # print(\"Inferred CP: \", potential_cps)\n", " cps = best_cps\n", " indices = best_indices\n", - " reconstruction = b.RENDERER.render(cps_to_pose(best_cps, best_indices), best_indices)[...,:3]\n", - " b.hstack_images([b.get_depth_image(obs_img[...,2]), *[b.get_depth_image(reconstruction[...,2])]])\n", + " reconstruction = b.RENDERER.render(\n", + " cps_to_pose(best_cps, best_indices), best_indices\n", + " )[..., :3]\n", + " b.hstack_images(\n", + " [\n", + " b.get_depth_image(obs_img[..., 2]),\n", + " *[b.get_depth_image(reconstruction[..., 2])],\n", + " ]\n", + " )\n", " print(\"GT Indices : \", gt_indices)\n", " print(\"Inf Indices : \", indices)\n", "\n", " inference_results.append((cps, indices))\n", - " data_dump.append((gt_cps, gt_indices, inference_results))\n" + " data_dump.append((gt_cps, gt_indices, inference_results))" ] }, { @@ -388,21 +446,28 @@ "images = []\n", "for i in tqdm(range(len(data_dump))):\n", " (gt_cps, gt_indices, inference_results) = data_dump[i]\n", - " obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[...,:3]\n", - " obs_viz = b.scale_image(b.get_depth_image(obs_img[...,2]), 3.0)\n", + " obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[..., :3]\n", + " obs_viz = b.scale_image(b.get_depth_image(obs_img[..., 2]), 3.0)\n", "\n", " viz = []\n", " for j in range(len(inference_results)):\n", " best_cps, best_indices = inference_results[j]\n", - " reconstruction = b.RENDERER.render(cps_to_pose(best_cps, best_indices), best_indices)[...,:3]\n", - " viz.append(b.scale_image(b.get_depth_image(reconstruction[...,2]), 3.0))\n", - " \n", - " images.append(b.multi_panel([obs_viz, *viz], labels=[\"Observed\", \"Inference Run 1\", \"Inference Run 2\"]))" + " reconstruction = b.RENDERER.render(\n", + " cps_to_pose(best_cps, best_indices), best_indices\n", + " )[..., :3]\n", + " viz.append(b.scale_image(b.get_depth_image(reconstruction[..., 2]), 3.0))\n", + "\n", + " images.append(\n", + " b.multi_panel(\n", + " [obs_viz, *viz], labels=[\"Observed\", \"Inference Run 1\", \"Inference Run 2\"]\n", + " )\n", + " )" ] }, { "cell_type": "code", "execution_count": 16, + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "outputs": [ { @@ -493,7 
+558,7 @@ "fps = 0.5\n", "for i in range(len(images)):\n", " images[i].convert(\"RGB\").save(\"%07d.png\" % i)\n", - "subprocess.call([\"ffmpeg\",\"-y\",\"-r\",str(fps),\"-i\", \"%07d.png\",\"out.mp4\"])" + "subprocess.call([\"ffmpeg\", \"-y\", \"-r\", str(fps), \"-i\", \"%07d.png\", \"out.mp4\"])" ] }, { @@ -519,17 +584,18 @@ "ax = fig.add_subplot(131)\n", "i = 1\n", "(gt_cps, gt_indices, inference_results) = data_dump[i]\n", - "obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[...,:3]\n", - "img_panels.append(b.add_depth_image(ax,obs_img[...,2]))\n", + "obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[..., :3]\n", + "img_panels.append(b.add_depth_image(ax, obs_img[..., 2]))\n", "ax.set_title(\"Observed\")\n", "\n", "for j in range(len(inference_results)):\n", - " ax = fig.add_subplot(1,3,j+2)\n", + " ax = fig.add_subplot(1, 3, j + 2)\n", " best_cps, best_indices = inference_results[j]\n", - " reconstruction = b.RENDERER.render(cps_to_pose(best_cps, best_indices), best_indices)[...,:3]\n", - " img_panels.append(b.add_depth_image(ax,reconstruction[...,2]))\n", - " ax.set_title(\"Run {}\".format(j+1))\n", - "\n" + " reconstruction = b.RENDERER.render(\n", + " cps_to_pose(best_cps, best_indices), best_indices\n", + " )[..., :3]\n", + " img_panels.append(b.add_depth_image(ax, reconstruction[..., 2]))\n", + " ax.set_title(\"Run {}\".format(j + 1))" ] }, { @@ -557,8 +623,8 @@ "source": [ "i = 1\n", "(gt_cps, gt_indices, inference_results) = data_dump[i]\n", - "obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[...,:3]\n", - "img_panels[0].set_array(b.preprocess_for_viz(obs_img[...,2]))\n", + "obs_img = b.RENDERER.render(cps_to_pose(gt_cps, gt_indices), gt_indices)[..., :3]\n", + "img_panels[0].set_array(b.preprocess_for_viz(obs_img[..., 2]))\n", "\n", "# for j in range(len(inference_results)):\n", "# ax = fig.add_subplot(1,3,j+2)\n", @@ -607,7 +673,7 @@ "metadata": {}, "outputs": [], "source": [ - "cps = jnp.zeros((0,3))\n", + "cps = jnp.zeros((0, 3))\n", "indices = jnp.array([], dtype=jnp.int32)" ] }, @@ -735,7 +801,7 @@ } ], "source": [ - "b.RENDERER.render_many(jnp.zeros((30000,1,4,4)), jnp.array([0])).shape" + "b.RENDERER.render_many(jnp.zeros((30000, 1, 4, 4)), jnp.array([0])).shape" ] }, { @@ -801,23 +867,38 @@ "%%time\n", "indices = jnp.array([11])\n", "\n", - "key = jax.random.split(key,2)[0]\n", - "cps = jax.random.uniform(key, shape=(len(indices),3,),minval=low, maxval=high)\n", + "key = jax.random.split(key, 2)[0]\n", + "cps = jax.random.uniform(\n", + " key,\n", + " shape=(\n", + " len(indices),\n", + " 3,\n", + " ),\n", + " minval=low,\n", + " maxval=high,\n", + ")\n", "cps_over_inference = [cps]\n", "poses = [cps_to_pose(cps)]\n", "for cp_grid in contact_param_gridding_schedule:\n", - " contact_param_grid = cps + cp_grid[:,None,...]\n", + " contact_param_grid = cps + cp_grid[:, None, ...]\n", " cp_poses = cps_to_pose_parallel(contact_param_grid)\n", - " scores = score_vmap(b.RENDERER.render_many(cp_poses, indices)[...,:3], obs_img)\n", + " scores = score_vmap(b.RENDERER.render_many(cp_poses, indices)[..., :3], obs_img)\n", " best_idx = jnp.argmax(scores)\n", " cps = contact_param_grid[best_idx]\n", " cps_over_inference.append(cps)\n", " poses.append(cp_poses[best_idx])\n", - "print(score_images(b.RENDERER.render(cps_to_pose(cps), indices)[...,:3], obs_img))\n", + "print(score_images(b.RENDERER.render(cps_to_pose(cps), indices)[..., :3], obs_img))\n", "print(\"GT CP: \", gt_cps)\n", "print(\"Inferred CP: \", 
cps)\n", - "images_over_time = b.RENDERER.render_many(cps_to_pose_parallel(jnp.stack(cps_over_inference))[...], indices)[...,:3]\n", - "b.hstack_images([b.get_depth_image(obs_img[...,2]), *[b.get_depth_image(i[...,2]) for i in images_over_time]])" + "images_over_time = b.RENDERER.render_many(\n", + " cps_to_pose_parallel(jnp.stack(cps_over_inference))[...], indices\n", + ")[..., :3]\n", + "b.hstack_images(\n", + " [\n", + " b.get_depth_image(obs_img[..., 2]),\n", + " *[b.get_depth_image(i[..., 2]) for i in images_over_time],\n", + " ]\n", + ")" ] }, { @@ -862,23 +943,27 @@ "all_all_paths = []\n", "for _ in range(3):\n", " all_paths = []\n", - " for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)):\n", + " for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)):\n", " path = []\n", - " trace_ = add_object_jit(trace, key, obj_id, 0, 2,3)\n", + " trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)\n", " number = b.get_contact_params(trace_).shape[0] - 1\n", " path.append(trace_)\n", " for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace_ = c2f_contact_update_jit(trace_, key, number,\n", - " contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID)\n", + " trace_ = c2f_contact_update_jit(\n", + " trace_,\n", + " key,\n", + " number,\n", + " contact_param_gridding_schedule[c2f_iter],\n", + " V_GRID,\n", + " O_GRID,\n", + " )\n", " path.append(trace_)\n", " # for c2f_iter in range(len(contact_param_gridding_schedule)):\n", " # trace_ = c2f_contact_update_jit(trace_, key, number,\n", " # contact_param_gridding_schedule[c2f_iter], VARIANCE_GRID, OUTLIER_GRID)\n", - " all_paths.append(\n", - " path\n", - " )\n", + " all_paths.append(path)\n", " all_all_paths.append(all_paths)\n", - " \n", + "\n", " scores = jnp.array([t[-1].get_score() for t in all_paths])\n", " print(scores)\n", " normalized_scores = b.utils.normalize_log_scores(scores)\n", @@ -915,10 +1000,18 @@ ], "source": [ "contact_param_grid = gt_cp + contact_param_deltas\n", - "scores = jnp.concatenate([\n", - " score_vmap(b.RENDERER.render_many(cp_to_pose_parallel(cps)[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", - " for cps in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)\n", + "scores = jnp.concatenate(\n", + " [\n", + " score_vmap(\n", + " b.RENDERER.render_many(\n", + " cp_to_pose_parallel(cps)[:, None, ...], jnp.array([13])\n", + " )[..., :3],\n", + " obs_img,\n", + " )\n", + " for cps in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")\n", "\n", "sort_order = jnp.argsort(-scores)\n", "sorted_scores = scores[sort_order]\n", @@ -926,10 +1019,12 @@ "print(\"GT CP: \", gt_cp)\n", "print(sorted_scores[:k])\n", "# print(contact_param_grid[sort_order[:k]])\n", - "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:,None,...]\n", - "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[...,:3]\n", + "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:, None, ...]\n", + "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[..., :3]\n", "\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -991,35 +1086,39 @@ ], "source": [ "r = 0.003\n", - "fig, ax = plt.subplots(1,1,figsize=(9,9))\n", - "ax.spines['top'].set_visible(False)\n", - "ax.spines['right'].set_visible(False)\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 9))\n", + 
"ax.spines[\"top\"].set_visible(False)\n", + "ax.spines[\"right\"].set_visible(False)\n", "\n", "from matplotlib.collections import LineCollection\n", + "\n", + "\n", "def line_collection(a, b, c=None, linewidth=1, **kwargs):\n", " lines = np.column_stack((a, b)).reshape(-1, 2, 2)\n", " lc = LineCollection(lines, colors=c, linewidths=linewidth, **kwargs)\n", " return lc\n", "\n", "\n", - "zorder=None\n", - "linewidth=2\n", + "zorder = None\n", + "linewidth = 2\n", "\n", "ps = jnp.array(contact_param_grid)\n", "sc = jnp.array(scores)\n", "\n", - "def unit_vec(hd): \n", + "\n", + "def unit_vec(hd):\n", " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n", "\n", - "a = ps[:,:2]\n", - "bs = a + r * jax.vmap(unit_vec)(ps[:,2])\n", - "q=0.0\n", "\n", - "clip=-1e12\n", - "cmap=\"viridis\"\n", + "a = ps[:, :2]\n", + "bs = a + r * jax.vmap(unit_vec)(ps[:, 2])\n", + "q = 0.0\n", + "\n", + "clip = -1e12\n", + "cmap = \"viridis\"\n", "sc = sc.reshape(-1)\n", - "sc = jnp.where(jnp==-jnp.inf, clip, sc)\n", - "sc = jnp.clip(sc, clip, jnp.max(sc))\n", + "sc = jnp.where(jnp == -jnp.inf, clip, sc)\n", + "sc = jnp.clip(sc, clip, jnp.max(sc))\n", "sc = jnp.clip(sc, jnp.quantile(sc, q), jnp.max(sc))\n", "cs = getattr(plt.cm, cmap)(plt.Normalize()(sc))\n", "\n", @@ -1028,8 +1127,8 @@ "bs = bs[order]\n", "cs = cs[order]\n", "\n", - "print(a,b)\n", - "ax.add_collection(line_collection(a,bs, c=cs, zorder=zorder, linewidth=linewidth))\n", + "print(a, b)\n", + "ax.add_collection(line_collection(a, bs, c=cs, zorder=zorder, linewidth=linewidth))\n", "\n", "p = gt_cp\n", "\n", @@ -1037,13 +1136,22 @@ "for i in sort_order[:10]:\n", " p = contact_param_grid[i]\n", " a = p[:2]\n", - " bs = a + r*unit_vec(p[2])\n", - " im = ax.plot([a[0],bs[0]],[a[1],bs[1]], c=\"blue\", zorder=zorder, linewidth=linewidth*2, alpha=0.5)\n", + " bs = a + r * unit_vec(p[2])\n", + " im = ax.plot(\n", + " [a[0], bs[0]],\n", + " [a[1], bs[1]],\n", + " c=\"blue\",\n", + " zorder=zorder,\n", + " linewidth=linewidth * 2,\n", + " alpha=0.5,\n", + " )\n", "\n", "p = gt_cp\n", "a = p[:2]\n", - "bs = a + r*unit_vec(p[2])\n", - "im = ax.plot([a[0],bs[0]],[a[1],bs[1]], c=\"red\", zorder=zorder, linewidth=linewidth, alpha=0.5)\n", + "bs = a + r * unit_vec(p[2])\n", + "im = ax.plot(\n", + " [a[0], bs[0]], [a[1], bs[1]], c=\"red\", zorder=zorder, linewidth=linewidth, alpha=0.5\n", + ")\n", "\n", "# p = ps[scores.argmax()]\n", "# a = p[:2]\n", @@ -1052,8 +1160,8 @@ "\n", "\n", "border = 0.01\n", - "ax.set_xlim(ps.min(0)[0]-border, ps.max(0)[0]+border)\n", - "ax.set_ylim(ps.min(0)[1]-border, ps.max(0)[1]+border)\n", + "ax.set_xlim(ps.min(0)[0] - border, ps.max(0)[0] + border)\n", + "ax.set_ylim(ps.min(0)[1] - border, ps.max(0)[1] + border)\n", "# plt.colorbar(im,cax=ax)" ] }, @@ -1100,8 +1208,8 @@ } ], "source": [ - "observed=obs_img\n", - "rendered=rendered_top_k[1]\n", + "observed = obs_img\n", + "rendered = rendered_top_k[1]\n", "distances = jnp.linalg.norm(observed - rendered, axis=-1)\n", "matches = distances < 0.01\n", "fig = plt.figure()\n", @@ -1109,10 +1217,10 @@ "print(matches.sum())\n", "ax.imshow(matches)\n", "ax = fig.add_subplot(132)\n", - "ax.imshow(b.viz.preprocess_for_viz(observed[...,2]))\n", + "ax.imshow(b.viz.preprocess_for_viz(observed[..., 2]))\n", "ax = fig.add_subplot(133)\n", - "ax.imshow(b.viz.preprocess_for_viz(observed[...,2]),alpha=0.5)\n", - "ax.imshow(matches,alpha=0.5)" + "ax.imshow(b.viz.preprocess_for_viz(observed[..., 2]), alpha=0.5)\n", + "ax.imshow(matches, alpha=0.5)" ] }, { @@ -1140,23 +1248,28 @@ "for cp_grid in 
contact_param_gridding_schedule:\n", " contact_param_grid = cp + cp_grid\n", " cp_poses = cp_to_pose_parallel(contact_param_grid)\n", - " scores = score_vmap(b.RENDERER.render_many(cp_poses[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", + " scores = score_vmap(\n", + " b.RENDERER.render_many(cp_poses[:, None, ...], jnp.array([13]))[..., :3],\n", + " obs_img,\n", + " )\n", " best_idx = jnp.argmax(scores)\n", " cp = contact_param_grid[best_idx]\n", " cps.append(cp)\n", " poses.append(cp_poses[best_idx])\n", - "images_over_time = b.RENDERER.render_many(cp_to_pose_parallel(jnp.stack(cps))[:,None,...], jnp.array([13]))[...,:3]\n", + "images_over_time = b.RENDERER.render_many(\n", + " cp_to_pose_parallel(jnp.stack(cps))[:, None, ...], jnp.array([13])\n", + ")[..., :3]\n", "\n", "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", "colors = b.distinct_colors(len(images_over_time))\n", "b.show_trimesh(f\"gt\", b.RENDERER.meshes[13], opacity=0.5, color=b.RED)\n", "b.set_pose(f\"gt\", gt_pose)\n", "for i in range(len(images_over_time)):\n", " b.show_trimesh(f\"_{i}\", b.RENDERER.meshes[13], opacity=0.5, color=colors[i])\n", " b.set_pose(f\"_{i}\", poses[i])\n", - " \n", - "b.hstack_images([b.get_depth_image(i[...,2]) for i in images_over_time])" + "\n", + "b.hstack_images([b.get_depth_image(i[..., 2]) for i in images_over_time])" ] }, { @@ -1179,10 +1292,18 @@ ], "source": [ "contact_param_grid = cp + contact_param_deltas\n", - "scores = jnp.concatenate([\n", - " score_vmap(b.RENDERER.render_many(cp_to_pose_parallel(cps)[:,None,...], jnp.array([13]))[...,:3], obs_img)\n", - " for cps in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)\n", + "scores = jnp.concatenate(\n", + " [\n", + " score_vmap(\n", + " b.RENDERER.render_many(\n", + " cp_to_pose_parallel(cps)[:, None, ...], jnp.array([13])\n", + " )[..., :3],\n", + " obs_img,\n", + " )\n", + " for cps in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")\n", "\n", "sort_order = jnp.argsort(-scores)\n", "sorted_scores = scores[sort_order]\n", @@ -1190,11 +1311,13 @@ "# print(\"GT CP: \", gt_cp)\n", "# print(sorted_scores[:k])\n", "# print(contact_param_grid[sort_order[:k]])\n", - "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:,None,...]\n", - "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[...,:3]\n", + "poses = cp_to_pose_parallel(contact_param_grid[sort_order[:k]])[:, None, ...]\n", + "rendered_top_k = b.RENDERER.render_many(poses, jnp.array([13]))[..., :3]\n", "\n", "\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -1255,9 +1378,11 @@ ], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", - "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1,3), color=b.RED)\n", - "b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", + "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1, 3), color=b.RED)\n", + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -1280,9 +1405,11 @@ ], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", - "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1,3), color=b.RED)\n", - 
"b.viz.scale_image(b.hstack_images([b.get_depth_image(i[...,2]) for i in rendered_top_k]),3.0)" + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", + "b.show_cloud(\"3\", rendered_top_k[1].reshape(-1, 3), color=b.RED)\n", + "b.viz.scale_image(\n", + " b.hstack_images([b.get_depth_image(i[..., 2]) for i in rendered_top_k]), 3.0\n", + ")" ] }, { @@ -1293,10 +1420,10 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", obs_img.reshape(-1,3))\n", + "b.show_cloud(\"1\", obs_img.reshape(-1, 3))\n", "colors = b.distinct_colors(len(images_over_time))\n", "for i in range(len(images_over_time)):\n", - " b.show_cloud(f\"_{i}\", images_over_time[i].reshape(-1,3), color=colors[i])" + " b.show_cloud(f\"_{i}\", images_over_time[i].reshape(-1, 3), color=colors[i])" ] }, { @@ -1481,10 +1608,13 @@ ], "source": [ "contact_param_grid = trace[\"contact_params_1\"] + contact_param_deltas\n", - "images = jnp.concatenate([\n", - " enumerators.enumerate_choices(trace, key, cp)[\"image\"]\n", - " for cp in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)" + "images = jnp.concatenate(\n", + " [\n", + " enumerators.enumerate_choices(trace, key, cp)[\"image\"]\n", + " for cp in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")" ] }, { @@ -1616,7 +1746,9 @@ } ], "source": [ - "traces = enumerators.enumerate_choices_get_scores(trace, key, contact_param_deltas + trace[\"contact_params_1\"])" + "traces = enumerators.enumerate_choices_get_scores(\n", + " trace, key, contact_param_deltas + trace[\"contact_params_1\"]\n", + ")" ] }, { @@ -1667,27 +1799,45 @@ " # fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", " widths = [1, 1]\n", " heights = [2]\n", - " spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", - " \n", + " spec = fig.add_gridspec(\n", + " ncols=2, nrows=1, width_ratios=widths, height_ratios=heights\n", + " )\n", + "\n", " ax = fig.add_subplot(spec[0, 0])\n", - " ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + " ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + " ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + " )\n", " # ax.set_title(f\"Observed Depth\")\n", - " \n", - " \n", + "\n", " ax = fig.add_subplot(spec[0, 1])\n", " ax.set_aspect(1.0)\n", - " circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle=\"--\", linewidth=0.5)\n", + " circ = plt.Circle(\n", + " (0, 0),\n", + " radius=1,\n", + " edgecolor=\"black\",\n", + " facecolor=\"None\",\n", + " linestyle=\"--\",\n", + " linewidth=0.5,\n", + " )\n", " ax.add_patch(circ)\n", " ax.set_xlim(-1.1, 1.1)\n", " ax.set_ylim(-1.1, 1.1)\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.scatter(-jnp.sin(sampled_contacts[:,2]),jnp.cos(sampled_contacts[:,2]), color='red',label=\"Posterior Samples\", alpha=0.5, s=30)\n", - " ax.scatter(-jnp.sin(gt_contact[2]),jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25)\n", + " ax.scatter(\n", + " -jnp.sin(sampled_contacts[:, 2]),\n", + " jnp.cos(sampled_contacts[:, 2]),\n", + " color=\"red\",\n", + " label=\"Posterior Samples\",\n", + " alpha=0.5,\n", + " s=30,\n", + " )\n", + " ax.scatter(\n", + " -jnp.sin(gt_contact[2]), 
jnp.cos(gt_contact[2]), label=\"Actual\", alpha=0.9, s=25\n", + " )\n", " ax.set_title(\"Posterior on Orientation (top view)\")\n", " # ax.legend(fontsize=9)\n", " # plt.show()\n", @@ -1705,10 +1855,9 @@ " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_1\"]\n", " scores = enumerators.enumerate_choices_get_scores(trace_, key, contact_param_grid)\n", " i = scores.argmax()\n", - " return enumerators.update_choices(\n", - " trace_, key,\n", - " contact_param_grid[i]\n", - " )\n", + " return enumerators.update_choices(trace_, key, contact_param_grid[i])\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update)" ] }, @@ -1731,16 +1880,18 @@ "outputs": [], "source": [ "grid_params = [\n", - " (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)), (0.02, jnp.pi, (9,9,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.3, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -1751,7 +1902,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]" + "key = jax.random.split(key, 2)[0]" ] }, { @@ -1781,24 +1932,36 @@ ], "source": [ "low, high = jnp.array([-0.2, -0.2, -jnp.pi]), jnp.array([0.2, 0.2, jnp.pi])\n", - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 3,\n", - " \"face_child_1\": 2,\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jax.random.uniform(key, shape=(3,),minval=low, maxval=high)\n", - "}), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.5, -0.5, -2*jnp.pi]), jnp.array([0.5, 0.5, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, 1.0, intrinsics.fx)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 3,\n", + " \"face_child_1\": 2,\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jax.random.uniform(\n", + " key, shape=(3,), minval=low, maxval=high\n", + " ),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.5, -0.5, -2 * jnp.pi]), jnp.array([0.5, 0.5, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " 1.0,\n", + " intrinsics.fx,\n", + " ),\n", ")\n", "gt_poses = b.get_poses(trace)\n", "gt_contact = trace[\"contact_params_1\"]\n", @@ -1839,7 +2002,9 
@@ "path = []\n", "path.append(trace)\n", "for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace = c2f_contact_update_jit(trace, key, contact_param_gridding_schedule[c2f_iter])\n", + " trace = c2f_contact_update_jit(\n", + " trace, key, contact_param_gridding_schedule[c2f_iter]\n", + " )\n", " path.append(trace)\n", "print(trace[\"contact_params_1\"])\n", "b.viz_trace_rendered_observed(trace)" @@ -1863,10 +2028,13 @@ "source": [ "%%time\n", "contact_param_grid = trace[\"contact_params_1\"] + contact_param_deltas\n", - "weights = jnp.concatenate([\n", - " enumerators.enumerate_choices_get_scores(trace, key, cp)\n", - " for cp in jnp.array_split(contact_param_grid, 15)\n", - "],axis=0)" + "weights = jnp.concatenate(\n", + " [\n", + " enumerators.enumerate_choices_get_scores(trace, key, cp)\n", + " for cp in jnp.array_split(contact_param_grid, 15)\n", + " ],\n", + " axis=0,\n", + ")" ] }, { @@ -1878,7 +2046,9 @@ "source": [ "key2 = jax.random.split(key, 1)[0]\n", "normalized_weights = b.utils.normalize_log_scores(weights)\n", - "sampled_indices = jax.random.choice(key2,jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights)\n", + "sampled_indices = jax.random.choice(\n", + " key2, jnp.arange(normalized_weights.shape[0]), shape=(2000,), p=normalized_weights\n", + ")\n", "sampled_contact_params = contact_param_grid[sampled_indices]" ] }, @@ -1928,21 +2098,21 @@ ], "source": [ "fig = plt.figure()\n", - "ax = fig.add_subplot(projection='3d')\n", + "ax = fig.add_subplot(projection=\"3d\")\n", "ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n", "# make the grid lines transparent\n", - "ax.xaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.yaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "ax.zaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", - "u, v = np.mgrid[0:2*np.pi:21j, 0:np.pi:11j]\n", - "x = np.cos(u)*np.sin(v)\n", - "y = np.sin(u)*np.sin(v)\n", + "ax.xaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.yaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "ax.zaxis._axinfo[\"grid\"][\"color\"] = (1, 1, 1, 0)\n", + "u, v = np.mgrid[0 : 2 * np.pi : 21j, 0 : np.pi : 11j]\n", + "x = np.cos(u) * np.sin(v)\n", + "y = np.sin(u) * np.sin(v)\n", "z = np.cos(v)\n", - "ax.axes.set_xlim3d(-1.1, 1.1) \n", - "ax.axes.set_ylim3d(-1.1, 1.1) \n", - "ax.axes.set_zlim3d(-1.1, 1.1) \n", + "ax.axes.set_xlim3d(-1.1, 1.1)\n", + "ax.axes.set_ylim3d(-1.1, 1.1)\n", + "ax.axes.set_zlim3d(-1.1, 1.1)\n", "ax.set_aspect(\"equal\")\n", "ax.plot_wireframe(x, y, z, color=(0.0, 0.0, 0.0, 0.3), linewidths=0.5)\n", "\n", @@ -1953,15 +2123,25 @@ "\n", "points = []\n", "NUM = 1\n", - "offset = jnp.pi/2\n", + "offset = jnp.pi / 2\n", "scaling = 0.96\n", "for i in sampled_contact_params:\n", - " points.append(np.array([np.cos(i[2] + offset) * scaling, np.sin(i[2] + offset) * scaling,0.0]))\n", + " points.append(\n", + " np.array(\n", + " [np.cos(i[2] + offset) * scaling, np.sin(i[2] + offset) * scaling, 0.0]\n", + " )\n", + " )\n", "points = np.array(points)\n", "\n", "z = 0.1\n", - "for i in np.arange(.1,1.01,.1):\n", - " ax.scatter(points[:,0], points[:,1],points[:,2], s=(40*i*(z*.9+.1))**2, color=(1,0,0,.3/i/10))\n", + "for i in np.arange(0.1, 1.01, 0.1):\n", + " ax.scatter(\n", + " points[:, 0],\n", + " points[:, 1],\n", + " points[:, 2],\n", + " s=(40 * i * (z * 0.9 + 0.1)) ** 2,\n", + " color=(1, 0, 0, 0.3 / i / 10),\n", + " )\n", "# offset = jnp.pi/2\n", "# angle = 
jnp.pi/4 - jnp.pi/4 - jnp.pi/4 - jnp.pi/4\n", "# for i in np.arange(.1,1.01,.1):\n", @@ -1981,13 +2161,18 @@ "scaled_up_intrinsics = b.scale_camera_parameters(intrinsics, 4)\n", "\n", "b.setup_renderer(scaled_up_intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", "# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/10.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -2007,14 +2192,14 @@ "metadata": {}, "outputs": [], "source": [ - "depth = img[...,2]\n", + "depth = img[..., 2]\n", "minval = jnp.min(depth[depth > jnp.min(depth)])\n", "maxval = jnp.max(depth[depth < jnp.max(depth)])\n", "depth = depth.at[depth >= intrinsics.far].set(jnp.nan)\n", - "viz_img = np.array(b.viz.scale_image(b.get_depth_image(\n", - " depth, min=minval, max=maxval\n", - "), 3))\n", - "viz_img[viz_img.sum(-1) == 0,:] = 255.0\n", + "viz_img = np.array(\n", + " b.viz.scale_image(b.get_depth_image(depth, min=minval, max=maxval), 3)\n", + ")\n", + "viz_img[viz_img.sum(-1) == 0, :] = 255.0\n", "plt.imshow(viz_img)\n", "plt.xticks([])\n", "plt.yticks([])\n", @@ -2057,7 +2242,7 @@ "metadata": {}, "outputs": [], "source": [ - "jnp.linalg.norm(points,axis=-1)" + "jnp.linalg.norm(points, axis=-1)" ] }, { @@ -2077,48 +2262,57 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "fig = plt.figure(constrained_layout=True)\n", "\n", - "observation = trace[\"image\"]\n", + "observation = trace[\"image\"]\n", "\n", "# fig.suptitle(f\"Variance: {variance} Outlier Prob: {outlier_prob}\")\n", "widths = [1, 1]\n", "heights = [2]\n", - "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", + "spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths, height_ratios=heights)\n", "\n", "ax = fig.add_subplot(spec[0, 0])\n", - "ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))\n", + "ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))\n", "ax.get_xaxis().set_visible(False)\n", "ax.get_yaxis().set_visible(False)\n", - "ax.set_title(f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\")\n", + "ax.set_title(\n", + " f\"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})\"\n", + ")\n", "# ax.set_title(f\"Observed Depth\")\n", "\n", "\n", "dist = 0.6\n", "ax = fig.add_subplot(spec[0, 1])\n", "ax.quiver(\n", - " sampled_contact_params[:,0],sampled_contact_params[:,1],\n", - " -jnp.sin(sampled_contact_params[:,2]),jnp.cos(sampled_contact_params[:,2]),\n", + " sampled_contact_params[:, 0],\n", + " sampled_contact_params[:, 1],\n", + " -jnp.sin(sampled_contact_params[:, 2]),\n", + " 
jnp.cos(sampled_contact_params[:, 2]),\n", " scale=3.0,\n", - " alpha=0.1\n", - " )\n", + " alpha=0.1,\n", + ")\n", "\n", "ax.quiver(\n", - " gt_contact[0],gt_contact[1],\n", - " -jnp.sin(gt_contact[2]), jnp.cos(gt_contact[2]),\n", + " gt_contact[0],\n", + " gt_contact[1],\n", + " -jnp.sin(gt_contact[2]),\n", + " jnp.cos(gt_contact[2]),\n", " scale=5.0,\n", " alpha=0.8,\n", - " color=\"red\"\n", + " color=\"red\",\n", ")\n", "\n", "ax.set_aspect(1.0)\n", "from matplotlib.patches import Rectangle\n", - "ax.add_patch(Rectangle((gt_contact[0]-width, gt_contact[1]-width), 2*width, 2*width,fill=None))\n", "\n", - "ax.set_xlim(gt_contact[0]-width-0.02, gt_contact[0]+width+0.02)\n", - "ax.set_ylim(gt_contact[1]-width-0.02, gt_contact[1]+width+0.02)" + "ax.add_patch(\n", + " Rectangle(\n", + " (gt_contact[0] - width, gt_contact[1] - width), 2 * width, 2 * width, fill=None\n", + " )\n", + ")\n", + "\n", + "ax.set_xlim(gt_contact[0] - width - 0.02, gt_contact[0] + width + 0.02)\n", + "ax.set_ylim(gt_contact[1] - width - 0.02, gt_contact[1] + width + 0.02)" ] }, { @@ -2133,8 +2327,8 @@ "best_cell_idx = jnp.abs(contact_param_grid - gt_contact).sum(1).argmin()\n", "print(gt_contact, contact_param_grid[best_cell_idx])\n", "normalize_log_weights = w1eights - b.logsumexp(weights)\n", - "assert(weights.shape[0] == contact_param_grid.shape[0])\n", - "volume = (width / num_position_grids)**2 * (2*jnp.pi / num_angle_grids)\n", + "assert weights.shape[0] == contact_param_grid.shape[0]\n", + "volume = (width / num_position_grids) ** 2 * (2 * jnp.pi / num_angle_grids)\n", "log_likelihood = normalize_log_weights[best_cell_idx] - jnp.log(volume)\n", "print(log_likelihood)" ] diff --git a/scripts/experiments/likelihood_debug/scene_parse_genjax.ipynb b/scripts/experiments/likelihood_debug/scene_parse_genjax.ipynb index 32b5bb90..ee8b5983 100644 --- a/scripts/experiments/likelihood_debug/scene_parse_genjax.ipynb +++ b/scripts/experiments/likelihood_debug/scene_parse_genjax.ipynb @@ -21,8 +21,9 @@ "import glob\n", "import bayes3d.neural\n", "import pickle\n", + "\n", "# Can be helpful for debugging:\n", - "# jax.config.update('jax_enable_checks', True) \n", + "# jax.config.update('jax_enable_checks', True)\n", "from bayes3d.neural.segmentation import carvekit_get_foreground_mask\n", "import genjax" ] @@ -64,9 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "paths = glob.glob(\n", - " \"panda_scans_v6/*.pkl\"\n", - ")\n", + "paths = glob.glob(\"panda_scans_v6/*.pkl\")\n", "all_data = pickle.load(open(paths[0], \"rb\"))\n", "IDX = 0\n", "data = all_data[IDX]" @@ -99,15 +98,17 @@ ], "source": [ "print(data[\"camera_image\"].keys())\n", - "K = data[\"camera_image\"]['camera_matrix'][0]\n", - "rgb = data[\"camera_image\"]['rgbPixels']\n", - "depth = data[\"camera_image\"]['depthPixels']\n", - "camera_pose = data[\"camera_image\"]['camera_pose']\n", + "K = data[\"camera_image\"][\"camera_matrix\"][0]\n", + "rgb = data[\"camera_image\"][\"rgbPixels\"]\n", + "depth = data[\"camera_image\"][\"depthPixels\"]\n", + "camera_pose = data[\"camera_image\"][\"camera_pose\"]\n", "camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)\n", - "fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]\n", - "h,w = depth.shape\n", + "fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n", + "h, w = depth.shape\n", "near = 0.001\n", - "rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))\n", + "rgbd_original = b.RGBD(\n", + " rgb, depth, camera_pose, b.Intrinsics(h, w, fx, fy, cx, cy, 0.001, 
10000.0)\n", + ")\n", "b.get_rgb_image(rgbd_original.rgb)" ] }, @@ -130,7 +131,7 @@ } ], "source": [ - "b.get_depth_image(rgbd_original.depth,max=1.5)" + "b.get_depth_image(rgbd_original.depth, max=1.5)" ] }, { @@ -153,8 +154,11 @@ "source": [ "table_pose, plane_dims = b.utils.infer_table_plane(\n", " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics),\n", - " jnp.eye(4), rgbd_scaled_down.intrinsics, \n", - " ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1\n", + " jnp.eye(4),\n", + " rgbd_scaled_down.intrinsics,\n", + " ransac_threshold=0.001,\n", + " inlier_threshold=0.001,\n", + " segmentation_threshold=0.1,\n", ")" ] }, @@ -166,7 +170,12 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"1\", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))\n", + "b.show_cloud(\n", + " \"1\",\n", + " b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(\n", + " -1, 3\n", + " ),\n", + ")\n", "b.show_pose(\"table\", table_pose)" ] }, @@ -195,17 +204,21 @@ "source": [ "b.setup_renderer(rgbd_scaled_down.intrinsics)\n", "b.RENDERER.add_mesh_from_file(\"toy_plane.ply\")\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(13+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(10+1).rjust(6, '0') + \".ply\")\n", - "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(13 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(10 + 1).rjust(6, \"0\") + \".ply\")\n", + "b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { "cell_type": "code", "execution_count": 13, + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "outputs": [], "source": [ @@ -214,25 +227,30 @@ "num_position_grids = 51\n", "num_angle_grids = 51\n", "contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " num_position_grids,num_position_grids,num_angle_grids\n", + " -width,\n", + " -width,\n", + " -ang,\n", + " width,\n", + " width,\n", + " ang,\n", + " num_position_grids,\n", + " num_position_grids,\n", + " num_angle_grids,\n", ")\n", "\n", "grid_params = [\n", - " (0.5, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),\n", - " (0.05, jnp.pi/3, (15,15,15)),\n", - " (0.02, jnp.pi, (9,9,51))\n", - " , (0.01, jnp.pi/5, (15,15,15)),\n", - " (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))\n", + " (0.5, jnp.pi, (15, 15, 15)),\n", + " (0.2, jnp.pi, (15, 15, 15)),\n", + " (0.1, jnp.pi, (15, 15, 15)),\n", + " (0.05, jnp.pi / 3, (15, 15, 15)),\n", + " (0.02, jnp.pi, (9, 9, 51)),\n", + " (0.01, jnp.pi / 5, (15, 15, 15)),\n", + " (0.01, 0.0, (31, 31, 1)),\n", + " (0.05, 0.0, (31, 31, 1)),\n", "]\n", 
"contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]" ] }, @@ -251,6 +269,7 @@ { "cell_type": "code", "execution_count": 15, + "id": "acae54e37e7d407bbb7b55eff062a284", "metadata": {}, "outputs": [ { @@ -269,6 +288,7 @@ { "cell_type": "code", "execution_count": 39, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": {}, "outputs": [], "source": [] @@ -276,6 +296,7 @@ { "cell_type": "code", "execution_count": 71, + "id": "8dd0d8092fe74a7c96281538738b07e2", "metadata": {}, "outputs": [ { @@ -288,24 +309,35 @@ ], "source": [ "obs_img = b.unproject_depth_jit(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics)\n", - "weight, trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(3),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"image\": obs_img,\n", - " \"variance\": 0.02,\n", - " \"outlier_prob\": 0.0001,\n", - " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0])\n", - "}), (\n", - " jnp.arange(1),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-12.0, -12.0, -22*jnp.pi]), jnp.array([12.0, 12.0, 22*jnp.pi])]),\n", - " b.RENDERER.model_box_dims)\n", + "weight, trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(3),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"image\": obs_img,\n", + " \"variance\": 0.02,\n", + " \"outlier_prob\": 0.0001,\n", + " \"contact_params_1\": jnp.array([0.0, 0.0, 0.0]),\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(1),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [\n", + " jnp.array([-12.0, -12.0, -22 * jnp.pi]),\n", + " jnp.array([12.0, 12.0, 22 * jnp.pi]),\n", + " ]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " ),\n", ")\n", "b.viz_trace_meshcat(trace)\n", "print(trace.get_score())" @@ -314,6 +346,7 @@ { "cell_type": "code", "execution_count": 78, + "id": "72eea5119410473aa328ad9291626812", "metadata": {}, "outputs": [], "source": [ @@ -323,14 +356,14 @@ "\n", "\n", "def c2f_(potential_trace, number, contact_param_gridding_schedule):\n", - " updater = jax.vmap(lambda trace, v: trace.update(\n", - " key,\n", - " genjax.choice_map({\n", - " f\"contact_params_{number}\": v\n", - " }),\n", - " b.make_unknown_change_argdiffs(trace)\n", - " )[2].get_score(), in_axes=(None, 0))\n", - "\n", + " updater = jax.vmap(\n", + " lambda trace, v: trace.update(\n", + " key,\n", + " genjax.choice_map({f\"contact_params_{number}\": v}),\n", + " b.make_unknown_change_argdiffs(trace),\n", + " )[2].get_score(),\n", + " in_axes=(None, 0),\n", + " )\n", "\n", " cp = potential_trace[address]\n", " for cp_grid in contact_param_gridding_schedule:\n", @@ -339,25 +372,34 @@ " cp = cps[scores.argmax()]\n", " potential_trace = enumerators.update_choices(potential_trace, key, cp)\n", " return potential_trace, scores.argmax()\n", + "\n", + "\n", "c2f = jax.jit(c2f_)" ] }, { "cell_type": "code", "execution_count": 83, + 
"id": "8edb47106e1a46a883d545849b8ab81b", "metadata": {}, "outputs": [], "source": [ - "key = jax.random.split(key,2)[0]\n", + "key = jax.random.split(key, 2)[0]\n", "low, high = jnp.array([-0.4, -0.4, -jnp.pi]), jnp.array([0.4, 0.4, jnp.pi])\n", - "potential_trace = b.add_object_jit(trace, key, 2, 0, 2,3)\n", - "potential_trace = b.update_address(potential_trace, key, address, jax.random.uniform(key, shape=(3,),minval=low, maxval=high))\n", + "potential_trace = b.add_object_jit(trace, key, 2, 0, 2, 3)\n", + "potential_trace = b.update_address(\n", + " potential_trace,\n", + " key,\n", + " address,\n", + " jax.random.uniform(key, shape=(3,), minval=low, maxval=high),\n", + ")\n", "b.viz_trace_meshcat(potential_trace)" ] }, { "cell_type": "code", "execution_count": 86, + "id": "10185d26023b46108eb7d9f57d49d2b3", "metadata": {}, "outputs": [ { @@ -380,7 +422,7 @@ ], "source": [ "%%time\n", - "key = jax.random.split(key,2)[0]\n", + "key = jax.random.split(key, 2)[0]\n", "potential_trace = c2f(potential_trace, 1, contact_param_gridding_schedule)[0]\n", "print(potential_trace[\"contact_params_1\"])\n", "b.viz_trace_meshcat(potential_trace)" @@ -389,6 +431,7 @@ { "cell_type": "code", "execution_count": 24, + "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": {}, "outputs": [ { diff --git a/scripts/experiments/mcs/cognitive-battery/model.ipynb b/scripts/experiments/mcs/cognitive-battery/model.ipynb index 48e67a98..38dabce7 100644 --- a/scripts/experiments/mcs/cognitive-battery/model.ipynb +++ b/scripts/experiments/mcs/cognitive-battery/model.ipynb @@ -94,10 +94,12 @@ "for frame_idx in tqdm(range(num_frames), desc=\"Loading frames\"):\n", " coord_image = np.array(unproject_depth(depth_images[frame_idx], intrinsics))\n", " segmentation_image = seg_maps[frame_idx].copy()\n", - " mask = np.logical_and.reduce([\n", - " *[(coord_image[:, :, i] > crops[i][0]) for i in range(len(crops))],\n", - " *[(coord_image[:, :, i] < crops[i][1]) for i in range(len(crops))], \n", - " ])\n", + " mask = np.logical_and.reduce(\n", + " [\n", + " *[(coord_image[:, :, i] > crops[i][0]) for i in range(len(crops))],\n", + " *[(coord_image[:, :, i] < crops[i][1]) for i in range(len(crops))],\n", + " ]\n", + " )\n", " mask = np.invert(mask)\n", "\n", " coord_image[mask, :] = 0.0\n", @@ -119,7 +121,8 @@ "meshes = []\n", "meshes_path = data_path.replace(\"videos\", \"meshes\")\n", "for mesh_name in os.listdir(meshes_path):\n", - " if not mesh_name.endswith(\".obj\"): continue\n", + " if not mesh_name.endswith(\".obj\"):\n", + " continue\n", " mesh_path = os.path.join(meshes_path, mesh_name)\n", " renderer.add_mesh_from_file(mesh_path, force=\"mesh\")\n", " meshes.append(mesh_name.replace(\".obj\", \"\"))\n", @@ -193,31 +196,34 @@ "def make_unfiform_grid(n, d):\n", " # d: number of enumerated proposals on each dimension (x, y, z).\n", " # n: the minimum and maximum position delta on each dimension (x, y, z).\n", - " return jax3dp3.make_translation_grid_enumeration(\n", - " -d, -d, -d, d, d, d, n, n, n\n", - " )\n", + " return jax3dp3.make_translation_grid_enumeration(-d, -d, -d, d, d, d, n, n, n)\n", + "\n", "\n", "def prior(new_pose, prev_poses):\n", - " new_pose = new_pose[:3,3]\n", + " new_pose = new_pose[:3, 3]\n", " gravity_shift = jnp.array([0.0, 0.2, 0.0])\n", - " velocity_shift = new_pose - prev_poses[:,:3,3].mean(axis=0)\n", - " \n", - " prev_pose = prev_poses[-1][:3,3]\n", + " velocity_shift = new_pose - prev_poses[:, :3, 3].mean(axis=0)\n", + "\n", + " prev_pose = prev_poses[-1][:3, 3]\n", " prior_shifts = 
gravity_shift + velocity_shift\n", - " weight = jax.scipy.stats.norm.logpdf(new_pose - (prev_pose + prior_shifts), loc=0, scale=0.1)\n", + " weight = jax.scipy.stats.norm.logpdf(\n", + " new_pose - (prev_pose + prior_shifts), loc=0, scale=0.1\n", + " )\n", " return weight.sum()\n", "\n", + "\n", "prior_parallel = jax.jit(jax.vmap(prior, in_axes=(0, None)))\n", "\n", + "\n", "def scorer(rendered_image, gt, r=0.1, op=0.005, ov=0.5):\n", " # Liklihood parameters\n", " # r: radius\n", " # op: outlier probability\n", " # ov: outlier volume\n", - " weight = jax3dp3.likelihood.threedp3_likelihood(\n", - " gt, rendered_image, r, op, ov\n", - " )\n", + " weight = jax3dp3.likelihood.threedp3_likelihood(gt, rendered_image, r, op, ov)\n", " return weight\n", + "\n", + "\n", "scorer_parallel = jax.jit(jax.vmap(scorer, in_axes=(0, None)))" ] }, @@ -265,7 +271,7 @@ "pose_estimates = init_poses.copy()\n", "past_poses = {i: deque([pose_estimates[i]]) for i in range(pose_estimates.shape[0])}\n", "for t in tqdm(range(start_t, start_t + num_steps)):\n", - " gt_image = jnp.array(coord_images[t]) \n", + " gt_image = jnp.array(coord_images[t])\n", " for _ in range(iterations_per_step):\n", " for i in range(n_objects):\n", " if i in set(containment_relations.values()):\n", @@ -275,37 +281,45 @@ " if i == reward_idx:\n", " occluded = utils.check_occlusion(renderer, pose_estimates, indices, i)\n", " if occluded:\n", - " containing_obj = utils.check_containment(renderer, pose_estimates, indices, i)\n", + " containing_obj = utils.check_containment(\n", + " renderer, pose_estimates, indices, i\n", + " )\n", " if containing_obj is not None:\n", " containment_relations[containing_obj] = i\n", " continue\n", - " \n", + "\n", " for d in [0.2, 0.1, 0.05]:\n", " translation_deltas = make_unfiform_grid(n=7, d=d)\n", " translation_deltas_full = jnp.tile(\n", " jnp.eye(4)[None, :, :],\n", " (translation_deltas.shape[0], pose_estimates.shape[0], 1, 1),\n", " )\n", - " translation_deltas_full = translation_deltas_full.at[:, i, :, :].set(translation_deltas)\n", + " translation_deltas_full = translation_deltas_full.at[:, i, :, :].set(\n", + " translation_deltas\n", + " )\n", " translation_proposals = jnp.einsum(\n", " \"bij,abjk->abik\", pose_estimates, translation_deltas_full\n", " )\n", - " images = renderer.render_multiobject_parallel(translation_proposals.transpose((1,0,2,3)), indices)\n", - " \n", - " weights = scorer_parallel(images, gt_image) + prior_parallel(translation_proposals[:,i], jnp.array(past_poses[i]))\n", + " images = renderer.render_multiobject_parallel(\n", + " translation_proposals.transpose((1, 0, 2, 3)), indices\n", + " )\n", + "\n", + " weights = scorer_parallel(images, gt_image) + prior_parallel(\n", + " translation_proposals[:, i], jnp.array(past_poses[i])\n", + " )\n", " best_weight_idx = jnp.argmax(weights)\n", " best_proposal = translation_proposals[best_weight_idx]\n", " pose_estimates = best_proposal\n", - " \n", + "\n", " past_poses[i].append(pose_estimates[i])\n", " if len(past_poses[i]) > num_past_poses:\n", " past_poses[i].popleft()\n", - " \n", + "\n", " for i, j in containment_relations.items():\n", " i_delta_pose = past_poses[i][-1] - past_poses[i][-2]\n", " new_pose_estimate = pose_estimates[j] + i_delta_pose\n", " pose_estimates = pose_estimates.at[j].set(new_pose_estimate)\n", - " \n", + "\n", " inferred_poses.append(pose_estimates.copy())" ] }, @@ -326,7 +340,7 @@ " apple_pose = poses[-1]\n", " rendered_apple = renderer.render_single_object(apple_pose, indices[-1])\n", " rendered_apple = 
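The tracking loop above expands each candidate delta into a full per-object delta stack (identity everywhere except object `i`), composes it with the current estimates via `jnp.einsum("bij,abjk->abik", ...)`, and keeps the proposal whose rendered image maximizes likelihood plus motion prior. The einsum is just a batched right-multiplication, meaning each delta acts in the object's local frame:

import jax.numpy as jnp

def compose_proposals(pose_estimates, deltas_full):
    # pose_estimates: (B, 4, 4) current poses for B objects.
    # deltas_full:    (A, B, 4, 4) A candidate deltas per object.
    # out[a, b] = pose_estimates[b] @ deltas_full[a, b]
    return jnp.einsum("bij,abjk->abik", pose_estimates, deltas_full)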
[get_depth_image(rendered_apple[:, :, 2], max=5)]\n", - " \n", + "\n", " all_images.append(\n", " multi_panel(\n", " [rgb_viz, gt_depth_1, rendered_image, *rendered_apple],\n", diff --git a/scripts/experiments/mcs/cognitive-battery/swap_model.ipynb b/scripts/experiments/mcs/cognitive-battery/swap_model.ipynb index 6f013706..bd49956e 100644 --- a/scripts/experiments/mcs/cognitive-battery/swap_model.ipynb +++ b/scripts/experiments/mcs/cognitive-battery/swap_model.ipynb @@ -121,7 +121,8 @@ "meshes = []\n", "meshes_path = data_path.replace(\"videos\", \"meshes\")\n", "for mesh_name in os.listdir(meshes_path):\n", - " if not mesh_name.endswith(\".obj\"): continue\n", + " if not mesh_name.endswith(\".obj\"):\n", + " continue\n", " mesh_path = os.path.join(meshes_path, mesh_name)\n", " renderer.add_mesh_from_file(mesh_path, force=\"mesh\")\n", " meshes.append(mesh_name.replace(\".obj\", \"\"))\n", @@ -210,11 +211,20 @@ " -reward_d, -reward_d, -reward_d, reward_d, reward_d, reward_d, n, n, n\n", ")\n", "reward_deltas_mask = jnp.abs(translation_deltas_reward[:, -2, -1]) > 1e-6\n", - "translation_deltas_reward = translation_deltas_reward.at[reward_deltas_mask, -2, -1].set(0)\n", + "translation_deltas_reward = translation_deltas_reward.at[\n", + " reward_deltas_mask, -2, -1\n", + "].set(0)\n", + "\n", "\n", "def prior(new_pose, prev_pose):\n", - " weight = jax.scipy.stats.norm.pdf(new_pose[:3,3] - (prev_pose[:3,3] + jnp.array([0.0, 0.2, 0.0])), loc=0, scale=0.1)\n", + " weight = jax.scipy.stats.norm.pdf(\n", + " new_pose[:3, 3] - (prev_pose[:3, 3] + jnp.array([0.0, 0.2, 0.0])),\n", + " loc=0,\n", + " scale=0.1,\n", + " )\n", " return weight.sum()\n", + "\n", + "\n", "prior_parallel = jax.jit(jax.vmap(prior, in_axes=(0, None)))\n", "\n", "\n", @@ -223,6 +233,8 @@ " gt, rendered_image, r, outlier_prob, outlier_volume\n", " )\n", " return weight\n", + "\n", + "\n", "scorer_parallel = jax.jit(jax.vmap(scorer, in_axes=(0, None)))" ] }, @@ -284,44 +296,56 @@ "inferred_poses = []\n", "pose_estimates = init_poses.copy()\n", "for t in tqdm(range(start_t, start_t + num_steps)):\n", - " gt_image = jnp.array(coord_images[t]) \n", + " gt_image = jnp.array(coord_images[t])\n", " for _ in range(iterations_per_step):\n", " for i in range(n_objects):\n", " if i in set(containment_relations.values()):\n", " continue\n", - " translation_deltas = translation_deltas_global if i != reward_idx else translation_deltas_reward\n", - " \n", + " translation_deltas = (\n", + " translation_deltas_global\n", + " if i != reward_idx\n", + " else translation_deltas_reward\n", + " )\n", + "\n", " # Check for occlusion\n", " if i == reward_idx:\n", " occluded = utils.check_occlusion(renderer, pose_estimates, indices, i)\n", " if occluded:\n", - " containing_obj = utils.check_containment(renderer, pose_estimates, indices, i)\n", + " containing_obj = utils.check_containment(\n", + " renderer, pose_estimates, indices, i\n", + " )\n", " if containing_obj is not None:\n", " containment_relations[containing_obj] = i\n", " continue\n", - " \n", + "\n", " translation_deltas_full = jnp.tile(\n", " jnp.eye(4)[None, :, :],\n", " (translation_deltas.shape[0], pose_estimates.shape[0], 1, 1),\n", " )\n", - " translation_deltas_full = translation_deltas_full.at[:, i, :, :].set(translation_deltas)\n", + " translation_deltas_full = translation_deltas_full.at[:, i, :, :].set(\n", + " translation_deltas\n", + " )\n", " translation_proposals = jnp.einsum(\n", " \"bij,abjk->abik\", pose_estimates, translation_deltas_full\n", " )\n", - " images = 
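Note an asymmetry between the two notebooks' priors: model.ipynb sums `norm.logpdf` values (a log-density), while this swap_model prior sums raw `norm.pdf` values (a density). Since the prior is added to the 3DP3 log-likelihood when forming the proposal weights, the logpdf form is the probabilistically consistent one; a sketch of the log-space combination under that assumption, reusing the gravity drift term from the cells above:

import jax
import jax.numpy as jnp

def log_posterior_score(log_likelihood, new_pose, prev_pose):
    # log p(image | pose) + log p(pose | previous pose)
    drift = new_pose[:3, 3] - (prev_pose[:3, 3] + jnp.array([0.0, 0.2, 0.0]))
    return log_likelihood + jax.scipy.stats.norm.logpdf(drift, loc=0.0, scale=0.1).sum()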
renderer.render_multiobject_parallel(translation_proposals.transpose((1,0,2,3)), indices)\n", + " images = renderer.render_multiobject_parallel(\n", + " translation_proposals.transpose((1, 0, 2, 3)), indices\n", + " )\n", "\n", - " weights = scorer_parallel(images, gt_image) + prior_parallel(translation_proposals[:,i], pose_estimates[i])\n", + " weights = scorer_parallel(images, gt_image) + prior_parallel(\n", + " translation_proposals[:, i], pose_estimates[i]\n", + " )\n", " # weights = scorer_parallel(images, gt_image)\n", " best_weight_idx = jnp.argmax(weights)\n", " best_proposal = translation_proposals[best_weight_idx]\n", - " \n", + "\n", " objs_deltas[i] = best_proposal[i] - pose_estimates[i]\n", " pose_estimates = best_proposal\n", - " \n", + "\n", " for i, j in containment_relations.items():\n", " new_pose_estimate = pose_estimates[j] + objs_deltas[i]\n", " pose_estimates = pose_estimates.at[j].set(new_pose_estimate)\n", - " \n", + "\n", " inferred_poses.append(pose_estimates.copy())" ] }, @@ -342,7 +366,7 @@ " apple_pose = poses[-1]\n", " rendered_apple = renderer.render_single_object(apple_pose, indices[-1])\n", " rendered_apple = [get_depth_image(rendered_apple[:, :, 2], max=5)]\n", - " \n", + "\n", " all_images.append(\n", " multi_panel(\n", " [rgb_viz, gt_depth_1, rendered_image, *rendered_apple],\n", diff --git a/scripts/experiments/mcs/otp_gen/otp_gen/pipeline.ipynb b/scripts/experiments/mcs/otp_gen/otp_gen/pipeline.ipynb index bb8bdc4a..e50fe995 100644 --- a/scripts/experiments/mcs/otp_gen/otp_gen/pipeline.ipynb +++ b/scripts/experiments/mcs/otp_gen/otp_gen/pipeline.ipynb @@ -25,73 +25,114 @@ "\n", "def test_ycb_loading():\n", " bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('52', '1', bop_ycb_dir)\n", + " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\n", + " \"52\", \"1\", bop_ycb_dir\n", + " )\n", "\n", " b.setup_renderer(rgbd.intrinsics, num_layers=1)\n", "\n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", - " for idx in range(1,22):\n", - " b.RENDERER.add_mesh_from_file(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"),scaling_factor=1.0/1000.0)\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + " for idx in range(1, 22):\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"),\n", + " scaling_factor=1.0 / 1000.0,\n", + " )\n", "\n", - " reconstruction_depth = b.RENDERER.render(gt_poses, gt_ids)[:,:,2]\n", + " reconstruction_depth = b.RENDERER.render(gt_poses, gt_ids)[:, :, 2]\n", " match_fraction = (jnp.abs(rgbd.depth - reconstruction_depth) < 0.05).mean()\n", " assert match_fraction > 0.2\n", "\n", + "\n", "bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('55', '22', bop_ycb_dir)\n", - "poses = jnp.concatenate([jnp.eye(4)[None,...], rgbd.camera_pose @ gt_poses],axis=0)\n", - "ids = jnp.concatenate([jnp.array([21]), gt_ids],axis=0)\n", + "rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\"55\", \"22\", bop_ycb_dir)\n", + "poses = jnp.concatenate([jnp.eye(4)[None, ...], rgbd.camera_pose @ gt_poses], axis=0)\n", + "ids = jnp.concatenate([jnp.array([21]), gt_ids], axis=0)\n", "\n", "\n", "b.setup_renderer(rgbd.intrinsics, num_layers=1)\n", "\n", - "model_dir 
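`test_ycb_loading` above sanity-checks the loader by rendering the ground-truth poses and requiring that enough pixels agree with the sensor depth. The metric is simply the fraction of pixels within a tolerance (0.05, in meters here, since the YCB meshes are rescaled from millimeters by the 1/1000 scaling factor):

import jax.numpy as jnp

def depth_match_fraction(observed_depth, rendered_depth, tol=0.05):
    # Fraction of pixels where the reconstruction matches the sensor to within tol.
    return (jnp.abs(observed_depth - rendered_depth) < tol).mean()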
=os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", - "for idx in range(1,22):\n", - " b.RENDERER.add_mesh_from_file(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"),scaling_factor=1.0/1000.0)\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + "for idx in range(1, 22):\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"),\n", + " scaling_factor=1.0 / 1000.0,\n", + " )\n", "\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")\n", "\n", "scene_graph = b.scene_graph.SceneGraph(\n", " root_poses=poses,\n", " box_dimensions=b.RENDERER.model_box_dims[ids],\n", " parents=jnp.full(poses.shape[0], -1),\n", - " contact_params=jnp.zeros((poses.shape[0],3)),\n", + " contact_params=jnp.zeros((poses.shape[0], 3)),\n", " face_parent=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", " face_child=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", ")\n", "assert jnp.isclose(scene_graph.get_poses(), poses).all()\n", "\n", - "def get_slack(scene_graph, parent_object_index, child_object_index, face_parent, face_child):\n", + "\n", + "def get_slack(\n", + " scene_graph, parent_object_index, child_object_index, face_parent, face_child\n", + "):\n", " parent_pose = scene_graph.get_poses()[parent_object_index]\n", " child_pose = scene_graph.get_poses()[child_object_index]\n", " dims_parent = scene_graph.box_dimensions[parent_object_index]\n", " dims_child = scene_graph.box_dimensions[child_object_index]\n", - " parent_contact_plane = parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", - " child_contact_plane = child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " parent_contact_plane = (\n", + " parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", + " )\n", + " child_contact_plane = (\n", + " child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " )\n", + "\n", + " contact_params, slack = b.scene_graph.closest_approximate_contact_params(\n", + " parent_contact_plane, child_contact_plane\n", + " )\n", + " return (\n", + " jnp.array([parent_object_index, child_object_index, face_parent, face_child]),\n", + " contact_params,\n", + " slack,\n", + " )\n", "\n", - " contact_params, slack = b.scene_graph.closest_approximate_contact_params(parent_contact_plane, child_contact_plane)\n", - " return jnp.array([parent_object_index, child_object_index, face_parent, face_child]), contact_params, slack\n", "\n", "add_edge_scene_graph = jax.jit(b.scene_graph.add_edge_scene_graph)\n", "\n", "N = poses.shape[0]\n", "# b.setup_visualizer()\n", "\n", - "get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True)))\n", + "get_slack_vmap = jax.jit(\n", + " b.utils.multivmap(get_slack, (False, False, False, True, True))\n", + ")\n", "\n", - "edges = [(0,1),(0,2),(0,3),(0,4),(0,6),(2,5)]\n", - "for i,j in edges:\n", - " settings, contact_params, slacks = get_slack_vmap(scene_graph, i,j, jnp.arange(6), jnp.arange(6))\n", - " settings = settings.reshape(-1,settings.shape[-1])\n", - " contact_params = contact_params.reshape(-1,contact_params.shape[-1])\n", - " error = jnp.abs(slacks - jnp.eye(4)).sum([-1,-2]).reshape(-1)\n", + 
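For each candidate edge, the loop below evaluates all 6x6 parent/child face pairings with `get_slack_vmap` and keeps the pairing whose slack transform is closest to the identity, measured by the elementwise L1 deviation summed over the 4x4 entries. The same selection written directly, assuming `slacks` has shape (6, 6, 4, 4):

import jax.numpy as jnp

def best_face_pair(slacks):
    err = jnp.abs(slacks - jnp.eye(4)).sum(axis=(-1, -2))  # (6, 6) pairing errors
    face_parent, face_child = jnp.unravel_index(jnp.argmin(err), err.shape)
    return face_parent, face_child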
"edges = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 6), (2, 5)]\n", + "for i, j in edges:\n", + " settings, contact_params, slacks = get_slack_vmap(\n", + " scene_graph, i, j, jnp.arange(6), jnp.arange(6)\n", + " )\n", + " settings = settings.reshape(-1, settings.shape[-1])\n", + " contact_params = contact_params.reshape(-1, contact_params.shape[-1])\n", + " error = jnp.abs(slacks - jnp.eye(4)).sum([-1, -2]).reshape(-1)\n", " indices = jnp.argsort(error.reshape(-1))\n", "\n", - " parent_object_index, child_object_index, face_parent, face_child = settings[indices[0]]\n", - " scene_graph = add_edge_scene_graph(scene_graph,parent_object_index, child_object_index, face_parent, face_child, contact_params[indices[0]])\n", + " parent_object_index, child_object_index, face_parent, face_child = settings[\n", + " indices[0]\n", + " ]\n", + " scene_graph = add_edge_scene_graph(\n", + " scene_graph,\n", + " parent_object_index,\n", + " child_object_index,\n", + " face_parent,\n", + " face_child,\n", + " contact_params[indices[0]],\n", + " )\n", "\n", "node_names = np.array([*b.utils.ycb_loader.MODEL_NAMES, \"table\"])\n", - "scene_graph.table_visualize(\"graph.png\", node_names=list(map(str,enumerate(node_names[ids]))))" + "scene_graph.table_visualize(\n", + " \"graph.png\", node_names=list(map(str, enumerate(node_names[ids])))\n", + ")" ] }, { @@ -100,12 +141,15 @@ "metadata": {}, "outputs": [], "source": [ - "from PIL import Image as im \n", + "from PIL import Image as im\n", + "\n", "# Show YCB Image\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "image.save(\"base_ycb/55_22.png\")\n", - "scene_graph.table_visualize(\"scene_graphs/55_22.svg\", node_names=list(map(str,enumerate(node_names[ids]))))\n", + "scene_graph.table_visualize(\n", + " \"scene_graphs/55_22.svg\", node_names=list(map(str, enumerate(node_names[ids])))\n", + ")\n", "image" ] }, @@ -115,7 +159,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Show Scene Graph \n", + "# Show Scene Graph\n", "graph_vis = im.open(\"graph.png\")\n", "graph_vis" ] @@ -127,20 +171,27 @@ "outputs": [], "source": [ "from bayes3d._rendering.photorealistic_renderers.kubric_interface import render_many\n", - "# create mesh_paths, could take from scene_graph construction. 
\n", - "model_dir =os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", + "\n", + "# create mesh_paths, could take from scene_graph construction.\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", "model_paths = []\n", "for model_name in node_names[ids]:\n", - " model_paths.append(os.path.join(model_dir,model_name +\"/textured.obj\"))\n", + " model_paths.append(os.path.join(model_dir, model_name + \"/textured.obj\"))\n", "\n", - "# add table as root node \n", + "# add table as root node\n", "model_paths[0] = os.path.join(b.utils.get_assets_dir(), \"sample_objs/table.obj\")\n", "poses = scene_graph.get_poses()\n", "intrinsics = rgbd.intrinsics\n", "scaling_factor = 1.0\n", "\n", - "# Pass through and render \n", - "outputs = render_many(model_paths, poses[None,...], intrinsics, scaling_factor=scaling_factor, camera_pose = rgbd.camera_pose)" + "# Pass through and render\n", + "outputs = render_many(\n", + " model_paths,\n", + " poses[None, ...],\n", + " intrinsics,\n", + " scaling_factor=scaling_factor,\n", + " camera_pose=rgbd.camera_pose,\n", + ")" ] }, { @@ -184,31 +235,40 @@ "import os\n", "from tqdm import tqdm\n", "from bayes3d._rendering.photorealistic_renderers.kubric_interface import render_many\n", - "from PIL import Image as im \n", + "from PIL import Image as im\n", "\n", "\n", "def ycb_load(test_set, image_number):\n", " test_set = str(test_set)\n", " image_number = str(image_number)\n", " bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(test_set, image_number, bop_ycb_dir)\n", - " poses = jnp.concatenate([jnp.eye(4)[None,...], rgbd.camera_pose @ gt_poses],axis=0)\n", - " ids = jnp.concatenate([jnp.array([21]), gt_ids],axis=0)\n", - "\n", + " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\n", + " test_set, image_number, bop_ycb_dir\n", + " )\n", + " poses = jnp.concatenate(\n", + " [jnp.eye(4)[None, ...], rgbd.camera_pose @ gt_poses], axis=0\n", + " )\n", + " ids = jnp.concatenate([jnp.array([21]), gt_ids], axis=0)\n", "\n", - " b.setup_renderer(rgbd.intrinsics, num_layers=1) \n", + " b.setup_renderer(rgbd.intrinsics, num_layers=1)\n", "\n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", - " for idx in range(1,22):\n", - " b.RENDERER.add_mesh_from_file(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"),scaling_factor=1.0/1000.0)\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + " for idx in range(1, 22):\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"),\n", + " scaling_factor=1.0 / 1000.0,\n", + " )\n", "\n", - " b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + " )\n", "\n", " scene_graph = b.scene_graph.SceneGraph(\n", " root_poses=poses,\n", " box_dimensions=b.RENDERER.model_box_dims[ids],\n", " parents=jnp.full(poses.shape[0], -1),\n", - " contact_params=jnp.zeros((poses.shape[0],3)),\n", + " contact_params=jnp.zeros((poses.shape[0], 3)),\n", " face_parent=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", " face_child=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", " )\n", @@ -218,53 +278,95 @@ "\n", " N 
= poses.shape[0]\n", "\n", - " get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True)))\n", + " get_slack_vmap = jax.jit(\n", + " b.utils.multivmap(get_slack, (False, False, False, True, True))\n", + " )\n", " return scene_graph, rgbd, ids\n", "\n", - "def get_slack(scene_graph, parent_object_index, child_object_index, face_parent, face_child):\n", + "\n", + "def get_slack(\n", + " scene_graph, parent_object_index, child_object_index, face_parent, face_child\n", + "):\n", " parent_pose = scene_graph.get_poses()[parent_object_index]\n", " child_pose = scene_graph.get_poses()[child_object_index]\n", " dims_parent = scene_graph.box_dimensions[parent_object_index]\n", " dims_child = scene_graph.box_dimensions[child_object_index]\n", - " parent_contact_plane = parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", - " child_contact_plane = child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " parent_contact_plane = (\n", + " parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", + " )\n", + " child_contact_plane = (\n", + " child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " )\n", + "\n", + " contact_params, slack = b.scene_graph.closest_approximate_contact_params(\n", + " parent_contact_plane, child_contact_plane\n", + " )\n", + " return (\n", + " jnp.array([parent_object_index, child_object_index, face_parent, face_child]),\n", + " contact_params,\n", + " slack,\n", + " )\n", + "\n", "\n", - " contact_params, slack = b.scene_graph.closest_approximate_contact_params(parent_contact_plane, child_contact_plane)\n", - " return jnp.array([parent_object_index, child_object_index, face_parent, face_child]), contact_params, slack\n", "add_edge_scene_graph = jax.jit(b.scene_graph.add_edge_scene_graph)\n", - "get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True)))\n", - "\n", - "def vis_scene_graph(scene_graph,edges, name = \"graph.png\"): \n", - " for i,j in edges:\n", - " settings, contact_params, slacks = get_slack_vmap(scene_graph, i,j, jnp.arange(6), jnp.arange(6))\n", - " settings = settings.reshape(-1,settings.shape[-1])\n", - " contact_params = contact_params.reshape(-1,contact_params.shape[-1])\n", - " error = jnp.abs(slacks - jnp.eye(4)).sum([-1,-2]).reshape(-1)\n", + "get_slack_vmap = jax.jit(\n", + " b.utils.multivmap(get_slack, (False, False, False, True, True))\n", + ")\n", + "\n", + "\n", + "def vis_scene_graph(scene_graph, edges, name=\"graph.png\"):\n", + " for i, j in edges:\n", + " settings, contact_params, slacks = get_slack_vmap(\n", + " scene_graph, i, j, jnp.arange(6), jnp.arange(6)\n", + " )\n", + " settings = settings.reshape(-1, settings.shape[-1])\n", + " contact_params = contact_params.reshape(-1, contact_params.shape[-1])\n", + " error = jnp.abs(slacks - jnp.eye(4)).sum([-1, -2]).reshape(-1)\n", " indices = jnp.argsort(error.reshape(-1))\n", "\n", - " parent_object_index, child_object_index, face_parent, face_child = settings[indices[0]]\n", - " scene_graph = add_edge_scene_graph(scene_graph,parent_object_index, child_object_index, face_parent, face_child, contact_params[indices[0]])\n", + " parent_object_index, child_object_index, face_parent, face_child = settings[\n", + " indices[0]\n", + " ]\n", + " scene_graph = add_edge_scene_graph(\n", + " scene_graph,\n", + " parent_object_index,\n", + " child_object_index,\n", + " face_parent,\n", + " face_child,\n", + " contact_params[indices[0]],\n", + " )\n", "\n", " 
node_names = np.array([*b.utils.ycb_loader.MODEL_NAMES, \"table\"])\n", - " scene_graph.table_visualize(name, node_names=list(map(str,enumerate(node_names[ids]))))\n", - " return node_names \n", + " scene_graph.table_visualize(\n", + " name, node_names=list(map(str, enumerate(node_names[ids])))\n", + " )\n", + " return node_names\n", + "\n", "\n", "def render_sg_kubric(node_names, scene_graph, rgbd):\n", - " # create mesh_paths, could take from scene_graph construction. \n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", + " # create mesh_paths, could take from scene_graph construction.\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", " model_paths = []\n", " for model_name in node_names[ids]:\n", - " model_paths.append(os.path.join(model_dir,model_name +\"/textured.obj\"))\n", + " model_paths.append(os.path.join(model_dir, model_name + \"/textured.obj\"))\n", "\n", - " # add table as root node \n", + " # add table as root node\n", " model_paths[0] = os.path.join(b.utils.get_assets_dir(), \"sample_objs/plane.obj\")\n", " poses = scene_graph.get_poses()\n", " intrinsics = rgbd.intrinsics\n", " scaling_factor = 1.0\n", "\n", " # Pass through and render (note; it seems the ycb requires a rotation)\n", - " camera_pose = rgbd.camera_pose @ b.t3d.transform_from_axis_angle(jnp.array([1.0, 0.0,0.0]), jnp.pi)\n", - " outputs = render_many(model_paths, poses[None,...], intrinsics, scaling_factor=scaling_factor, camera_pose = rgbd.camera_pose)\n", + " camera_pose = rgbd.camera_pose @ b.t3d.transform_from_axis_angle(\n", + " jnp.array([1.0, 0.0, 0.0]), jnp.pi\n", + " )\n", + " outputs = render_many(\n", + " model_paths,\n", + " poses[None, ...],\n", + " intrinsics,\n", + " scaling_factor=scaling_factor,\n", + " camera_pose=rgbd.camera_pose,\n", + " )\n", " return outputs" ] }, @@ -281,10 +383,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 49,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 49, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -296,10 +398,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (4,1)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (4, 1)]\n", "\n", "name = f\"{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)\n", "graph_vis = im.open(name)\n", "graph_vis" @@ -338,9 +440,9 @@ "metadata": {}, "outputs": [], "source": [ - "scene_graph, rgbd, ids = ycb_load(51,1)\n", + "scene_graph, rgbd, ids = ycb_load(51, 1)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "image.save(f\"base_ycb/{51}_{1}.png\")\n", "image" ] @@ -351,10 +453,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,1), (0,3),(0,4),(0,5),(5,2)]\n", + "edges = [(0, 1), (0, 3), (0, 4), (0, 5), (5, 2)]\n", "\n", "name = \"scene_graphs/51_1.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names[ids])\n", "graph_vis = im.open(name)\n", "graph_vis" @@ -385,10 +487,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 50,620\n", - "scene_graph, rgbd, 
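One thing to flag in `render_sg_kubric` above: a rotated camera pose is computed (per the comment that YCB "requires a rotation"), but the `render_many` call then passes the original `rgbd.camera_pose`, leaving the rotated `camera_pose` variable unused. If the rotation really is needed, the call presumably should read:

# Presumed intent -- pass the rotated pose rather than rgbd.camera_pose:
outputs = render_many(
    model_paths,
    poses[None, ...],
    intrinsics,
    scaling_factor=scaling_factor,
    camera_pose=camera_pose,
)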
ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 50, 620\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -400,10 +502,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)\n", "# graph_vis = im.open(name)\n", "# graph_vis" @@ -434,10 +536,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 54,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 54, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -449,10 +551,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)" ] }, @@ -481,10 +583,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 56,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 56, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -496,10 +598,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)" ] }, @@ -528,10 +630,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 57,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 57, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -543,10 +645,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (5,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (5, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names[ids])" ] }, @@ -575,10 +677,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 58,30\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 58, 30\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = 
f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -590,10 +692,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)" ] }, @@ -622,10 +724,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 59,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 59, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -637,10 +739,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (0,3), (0,1), (0,5), (0,6)]\n", + "edges = [(0, 2), (0, 4), (0, 3), (0, 1), (0, 5), (0, 6)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names)" ] }, @@ -669,10 +771,10 @@ "metadata": {}, "outputs": [], "source": [ - "set_num, img_num = 48,1\n", - "scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", + "set_num, img_num = 48, 1\n", + "scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", "np_rep = np.array(rgbd.rgb.copy())\n", - "image = im.fromarray(np_rep)\n", + "image = im.fromarray(np_rep)\n", "file = f\"{set_num}_{img_num}\"\n", "image.save(f\"base_ycb/{file}.png\")\n", "image" @@ -684,10 +786,10 @@ "metadata": {}, "outputs": [], "source": [ - "edges = [(0,2), (0,4), (1,3), (0,1), (0,5)]\n", + "edges = [(0, 2), (0, 4), (1, 3), (0, 1), (0, 5)]\n", "\n", "name = f\"scene_graphs/{file}.svg\"\n", - "node_names = vis_scene_graph(scene_graph, edges,name)\n", + "node_names = vis_scene_graph(scene_graph, edges, name)\n", "print(node_names[ids])" ] }, @@ -726,32 +828,52 @@ "import os\n", "from tqdm import tqdm\n", "from bayes3d._rendering.photorealistic_renderers.kubric_interface import render_many\n", - "from PIL import Image as im \n", + "from PIL import Image as im\n", + "\n", + "scenes = [\n", + " [49, 1],\n", + " [51, 1],\n", + " [50, 620],\n", + " [54, 1],\n", + " [56, 1],\n", + " [57, 1],\n", + " [58, 30],\n", + " [59, 1],\n", + " [48, 1],\n", + "]\n", "\n", - "scenes = [[49,1],[51,1], [50,620], [54,1], [56,1], [57,1], [58,30], [59,1], [48,1]]\n", "\n", "def ycb_load(test_set, image_number):\n", " test_set = str(test_set)\n", " image_number = str(image_number)\n", " bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv\")\n", - " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(test_set, image_number, bop_ycb_dir)\n", - " poses = jnp.concatenate([jnp.eye(4)[None,...], rgbd.camera_pose @ gt_poses],axis=0)\n", - " ids = jnp.concatenate([jnp.array([21]), gt_ids],axis=0)\n", - "\n", + " rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img(\n", + " test_set, image_number, bop_ycb_dir\n", + " )\n", + " poses = jnp.concatenate(\n", + " [jnp.eye(4)[None, ...], rgbd.camera_pose @ gt_poses], axis=0\n", + " )\n", + " ids = jnp.concatenate([jnp.array([21]), gt_ids], axis=0)\n", "\n", - " b.setup_renderer(rgbd.intrinsics, num_layers=1) \n", + " b.setup_renderer(rgbd.intrinsics, num_layers=1)\n", "\n", - " model_dir =os.path.join(b.utils.get_assets_dir(), 
\"bop/ycbv/models\")\n", - " for idx in range(1,22):\n", - " b.RENDERER.add_mesh_from_file(os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\"),scaling_factor=1.0/1000.0)\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", + " for idx in range(1, 22):\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"),\n", + " scaling_factor=1.0 / 1000.0,\n", + " )\n", "\n", - " b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n", + " b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + " )\n", "\n", " scene_graph = b.scene_graph.SceneGraph(\n", " root_poses=poses,\n", " box_dimensions=b.RENDERER.model_box_dims[ids],\n", " parents=jnp.full(poses.shape[0], -1),\n", - " contact_params=jnp.zeros((poses.shape[0],3)),\n", + " contact_params=jnp.zeros((poses.shape[0], 3)),\n", " face_parent=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", " face_child=jnp.zeros(poses.shape[0], dtype=jnp.int32),\n", " )\n", @@ -761,66 +883,107 @@ "\n", " N = poses.shape[0]\n", "\n", - " get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True)))\n", + " get_slack_vmap = jax.jit(\n", + " b.utils.multivmap(get_slack, (False, False, False, True, True))\n", + " )\n", " return scene_graph, rgbd, ids\n", "\n", - "def get_slack(scene_graph, parent_object_index, child_object_index, face_parent, face_child):\n", + "\n", + "def get_slack(\n", + " scene_graph, parent_object_index, child_object_index, face_parent, face_child\n", + "):\n", " parent_pose = scene_graph.get_poses()[parent_object_index]\n", " child_pose = scene_graph.get_poses()[child_object_index]\n", " dims_parent = scene_graph.box_dimensions[parent_object_index]\n", " dims_child = scene_graph.box_dimensions[child_object_index]\n", - " parent_contact_plane = parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", - " child_contact_plane = child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " parent_contact_plane = (\n", + " parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]\n", + " )\n", + " child_contact_plane = (\n", + " child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]\n", + " )\n", + "\n", + " contact_params, slack = b.scene_graph.closest_approximate_contact_params(\n", + " parent_contact_plane, child_contact_plane\n", + " )\n", + " return (\n", + " jnp.array([parent_object_index, child_object_index, face_parent, face_child]),\n", + " contact_params,\n", + " slack,\n", + " )\n", + "\n", "\n", - " contact_params, slack = b.scene_graph.closest_approximate_contact_params(parent_contact_plane, child_contact_plane)\n", - " return jnp.array([parent_object_index, child_object_index, face_parent, face_child]), contact_params, slack\n", "add_edge_scene_graph = jax.jit(b.scene_graph.add_edge_scene_graph)\n", - "get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True)))\n", - "\n", - "def vis_scene_graph(scene_graph, edges, name = \"graph.png\"): \n", - " for i,j in edges:\n", - " settings, contact_params, slacks = get_slack_vmap(scene_graph, i,j, jnp.arange(6), jnp.arange(6))\n", - " settings = settings.reshape(-1,settings.shape[-1])\n", - " contact_params = contact_params.reshape(-1,contact_params.shape[-1])\n", - " error = 
jnp.abs(slacks - jnp.eye(4)).sum([-1,-2]).reshape(-1)\n", + "get_slack_vmap = jax.jit(\n", + " b.utils.multivmap(get_slack, (False, False, False, True, True))\n", + ")\n", + "\n", + "\n", + "def vis_scene_graph(scene_graph, edges, name=\"graph.png\"):\n", + " for i, j in edges:\n", + " settings, contact_params, slacks = get_slack_vmap(\n", + " scene_graph, i, j, jnp.arange(6), jnp.arange(6)\n", + " )\n", + " settings = settings.reshape(-1, settings.shape[-1])\n", + " contact_params = contact_params.reshape(-1, contact_params.shape[-1])\n", + " error = jnp.abs(slacks - jnp.eye(4)).sum([-1, -2]).reshape(-1)\n", " indices = jnp.argsort(error.reshape(-1))\n", "\n", - " parent_object_index, child_object_index, face_parent, face_child = settings[indices[0]]\n", - " scene_graph = add_edge_scene_graph(scene_graph,parent_object_index, child_object_index, face_parent, face_child, contact_params[indices[0]])\n", + " parent_object_index, child_object_index, face_parent, face_child = settings[\n", + " indices[0]\n", + " ]\n", + " scene_graph = add_edge_scene_graph(\n", + " scene_graph,\n", + " parent_object_index,\n", + " child_object_index,\n", + " face_parent,\n", + " face_child,\n", + " contact_params[indices[0]],\n", + " )\n", "\n", " node_names = np.array([*b.utils.ycb_loader.MODEL_NAMES, \"table\"])\n", - " scene_graph.table_visualize(name, node_names=list(map(str,enumerate(node_names[ids]))))\n", - " return node_names \n", + " scene_graph.table_visualize(\n", + " name, node_names=list(map(str, enumerate(node_names[ids])))\n", + " )\n", + " return node_names\n", + "\n", "\n", "def render_sg_kubric(node_names, scene_graph, rgbd):\n", - " # create mesh_paths, could take from scene_graph construction. \n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", + " # create mesh_paths, could take from scene_graph construction.\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", " model_paths = []\n", " for model_name in node_names[ids]:\n", - " model_paths.append(os.path.join(model_dir,model_name +\"/textured.obj\"))\n", + " model_paths.append(os.path.join(model_dir, model_name + \"/textured.obj\"))\n", "\n", - " # add table as root node \n", + " # add table as root node\n", " model_paths[0] = os.path.join(b.utils.get_assets_dir(), \"sample_objs/plane.obj\")\n", " poses = scene_graph.get_poses()\n", " intrinsics = rgbd.intrinsics\n", " scaling_factor = 1.0\n", "\n", - " # Pass through and render \n", - " outputs = render_many(model_paths, poses[None,...], intrinsics, scaling_factor=scaling_factor, camera_pose = rgbd.camera_pose)\n", + " # Pass through and render\n", + " outputs = render_many(\n", + " model_paths,\n", + " poses[None, ...],\n", + " intrinsics,\n", + " scaling_factor=scaling_factor,\n", + " camera_pose=rgbd.camera_pose,\n", + " )\n", " return outputs\n", "\n", - "def model_paths(node_names, scene_graph, rgbd): \n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", + "\n", + "def model_paths(node_names, scene_graph, rgbd):\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", " model_paths = []\n", " for model_name in node_names[ids]:\n", - " model_paths.append(os.path.join(model_dir,model_name +\"/textured.obj\"))\n", + " model_paths.append(os.path.join(model_dir, model_name + \"/textured.obj\"))\n", "\n", - " # add table as root node \n", + " # add table as root node\n", " model_paths[0] = os.path.join(b.utils.get_assets_dir(), 
\"sample_objs/plane.obj\")\n", " poses = scene_graph.get_poses()\n", " intrinsics = rgbd.intrinsics\n", " scaling_factor = 1.0\n", - " scene = pbs.Scene() " + " scene = pbs.Scene()" ] }, { @@ -829,17 +992,17 @@ "metadata": {}, "outputs": [], "source": [ - "for set_num, img_num in scenes: \n", - " scene_graph, rgbd, ids = ycb_load(set_num,img_num)\n", - " edges = [] \n", + "for set_num, img_num in scenes:\n", + " scene_graph, rgbd, ids = ycb_load(set_num, img_num)\n", + " edges = []\n", " node_names = vis_scene_graph(scene_graph, edges)\n", "\n", - " model_dir =os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", + " model_dir = os.path.join(b.utils.get_assets_dir(), \"ycb_video_models/models\")\n", " model_paths = []\n", " for model_name in node_names[ids]:\n", - " model_paths.append(os.path.join(model_dir,model_name +\"/textured.obj\"))\n", + " model_paths.append(os.path.join(model_dir, model_name + \"/textured.obj\"))\n", "\n", - " # add table as root node \n", + " # add table as root node\n", " model_paths[0] = os.path.join(b.utils.get_assets_dir(), \"sample_objs/plane.obj\")\n", " poses = scene_graph.get_poses()\n", " intrinsics = rgbd.intrinsics\n", @@ -847,25 +1010,25 @@ " scene = pbs.Scene(floor=True)\n", " scene.set_camera_pose(rgbd.camera_pose)\n", " scene.camera.set_intrinsics(intrinsics)\n", - " scene.set_gravity([0,0,-1])\n", + " scene.set_gravity([0, 0, -1])\n", "\n", " body_names = node_names[ids]\n", - " for i, name in enumerate(body_names): \n", - " if i != 0: \n", + " for i, name in enumerate(body_names):\n", + " if i != 0:\n", " body_pose = poses[i]\n", " mesh = model_paths[i]\n", - " body = pbs.make_body_from_obj_pose(mesh, body_pose, id = name)\n", + " body = pbs.make_body_from_obj_pose(mesh, body_pose, id=name)\n", " body.set_restitution(0)\n", " scene.add_body(body)\n", "\n", " path = \"../assets/sample_objs/sphere.obj\"\n", - " position = [0,-1,.1]\n", - " velocity = [0,4,0]\n", - " scale = np.array([1,1,1]) * .1 \n", - " bowling = pbs.make_body_from_obj(path, position, scale = scale,id = \"bowling\")\n", + " position = [0, -1, 0.1]\n", + " velocity = [0, 4, 0]\n", + " scale = np.array([1, 1, 1]) * 0.1\n", + " bowling = pbs.make_body_from_obj(path, position, scale=scale, id=\"bowling\")\n", " bowling.set_velocity(velocity)\n", " bowling.set_restitution(0)\n", - " bowling.set_color(np.array([.1,.1,.1]))\n", + " bowling.set_color(np.array([0.1, 0.1, 0.1]))\n", " bowling.set_mass = 10000\n", " scene.add_body(bowling)\n", "\n", diff --git a/scripts/experiments/mcs/otp_gen/otp_gen/pyb_sim.ipynb b/scripts/experiments/mcs/otp_gen/otp_gen/pyb_sim.ipynb index 1142ad27..5d554ff3 100644 --- a/scripts/experiments/mcs/otp_gen/otp_gen/pyb_sim.ipynb +++ b/scripts/experiments/mcs/otp_gen/otp_gen/pyb_sim.ipynb @@ -22,11 +22,11 @@ "\n", "# define a position and create a sphere\n", "position = np.array([0, -2, 2])\n", - "sphere = pbs.make_sphere(position, [1,1,1]) \n", + "sphere = pbs.make_sphere(position, [1, 1, 1])\n", "\n", "# add the sphere to the scene\n", "scene.add_body(sphere)\n", - "image,depth, segm= scene.render(pbs.pybullet_render) \n", + "image, depth, segm = scene.render(pbs.pybullet_render)\n", "image" ] }, @@ -56,12 +56,12 @@ "d20_pose = np.eye(4)\n", "d20_rot = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])\n", "d20_pos = np.array([0, 0, 1])\n", - "d20_pose[:3,:3] = d20_rot\n", - "d20_pose[:3,3] = d20_pos\n", + "d20_pose[:3, :3] = d20_rot\n", + "d20_pose[:3, 3] = d20_pos\n", "\n", "# create a d20\n", - "d20 = pbs.make_body_from_obj_pose(path_to_mesh, 
d20_pose, id = \"d20\")\n", - "d20.set_color([1,0,1])\n", + "d20 = pbs.make_body_from_obj_pose(path_to_mesh, d20_pose, id=\"d20\")\n", + "d20.set_color([1, 0, 1])\n", "scene.add_body(d20)\n", "image, depth, segm = scene.render(pbs.pybullet_render)\n", "image" @@ -96,7 +96,7 @@ "# Create a colormap for the number of classes\n", "# num_classes = np.max(class_array) + 1\n", "num_classes = 4\n", - "cmap = plt.cm.get_cmap('tab10', num_classes)\n", + "cmap = plt.cm.get_cmap(\"tab10\", num_classes)\n", "\n", "# Create an RGB image based on the class array and colormap\n", "class_image = cmap(class_array)\n", @@ -104,7 +104,7 @@ "\n", "# Display the class image\n", "plt.imshow(class_image)\n", - "plt.axis('off')\n", + "plt.axis(\"off\")\n", "plt.show()" ] }, @@ -127,23 +127,25 @@ "\n", "scene = pbs.Scene()\n", "\n", - "# default cube, length 1 at origin \n", - "original_cube = pbs.make_box(np.array([0,0,0.5]), id=\"original_cube\")\n", - "original_cube.set_color([0,1,0])\n", + "# default cube, length 1 at origin\n", + "original_cube = pbs.make_box(np.array([0, 0, 0.5]), id=\"original_cube\")\n", + "original_cube.set_color([0, 1, 0])\n", "\n", - "# wide box, set scale after construction \n", - "second_cube = pbs.make_box(np.array([-2,0,0.5]), id=\"second_cube\")\n", - "second_cube.set_color([1,0,0])\n", - "second_cube.set_scale(np.array([2,1,1]))\n", + "# wide box, set scale after construction\n", + "second_cube = pbs.make_box(np.array([-2, 0, 0.5]), id=\"second_cube\")\n", + "second_cube.set_color([1, 0, 0])\n", + "second_cube.set_scale(np.array([2, 1, 1]))\n", "\n", "# tall box, can set scale at construction\n", - "third_cube = pbs.make_box(np.array([2,0,1]), np.array([1,1,2]), id=\"third_cube\")\n", - "third_cube.set_color([0,0,1])\n", + "third_cube = pbs.make_box(np.array([2, 0, 1]), np.array([1, 1, 2]), id=\"third_cube\")\n", + "third_cube.set_color([0, 0, 1])\n", "\n", - "# custom obj, large scale \n", - "diamond = pbs.make_body_from_obj(\"../assets/sample_objs/pyramid.obj\", np.array([0,2,2]), id=\"diamond\")\n", - "diamond.set_scale(np.array([6,3,1]))\n", - "diamond.set_color([.5,.5,.5])\n", + "# custom obj, large scale\n", + "diamond = pbs.make_body_from_obj(\n", + " \"../assets/sample_objs/pyramid.obj\", np.array([0, 2, 2]), id=\"diamond\"\n", + ")\n", + "diamond.set_scale(np.array([6, 3, 1]))\n", + "diamond.set_color([0.5, 0.5, 0.5])\n", "\n", "# add all bodies to scene\n", "scene.add_body(original_cube)\n", @@ -151,11 +153,11 @@ "scene.add_body(third_cube)\n", "scene.add_body(diamond)\n", "\n", - "# could also have added all at once \n", + "# could also have added all at once\n", "# scene.add_bodies([original_cube, second_cube, third_cube, diamond])\n", "\n", "image, depth, segm = scene.render(pbs.pybullet_render)\n", - "image\n" + "image" ] }, { @@ -174,33 +176,33 @@ "source": [ "import bayes3d.utils.pybullet_sim as pbs\n", "import bayes3d as b\n", - "import numpy as np \n", - "from PIL import Image as im \n", + "import numpy as np\n", + "from PIL import Image as im\n", "\n", - "# try rendering simple scene \n", + "# try rendering simple scene\n", "\n", - "scene = pbs.Scene() \n", + "scene = pbs.Scene()\n", "\n", - "path_to_obj = \"../assets/sample_objs/diamond.obj\" \n", - "position = np.array([-4,6,2])\n", - "diamond = pbs.make_body_from_obj(path_to_obj, position, id=\"diamond\")\n", - "diamond.set_color(np.array([0,1,0]))\n", - "diamond.set_scale(np.array([3,3,3]))\n", - "diamond.set_velocity(np.array([30,0,0]))\n", + "path_to_obj = \"../assets/sample_objs/diamond.obj\"\n", + 
"position = np.array([-4, 6, 2])\n", + "diamond = pbs.make_body_from_obj(path_to_obj, position, id=\"diamond\")\n", + "diamond.set_color(np.array([0, 1, 0]))\n", + "diamond.set_scale(np.array([3, 3, 3]))\n", + "diamond.set_velocity(np.array([30, 0, 0]))\n", "scene.add_body(diamond)\n", "\n", - "wall = pbs.make_box([0,0,2], [4,1,2], id = \"wall\")\n", - "wall.set_color(np.array([1,1,0]))\n", + "wall = pbs.make_box([0, 0, 2], [4, 1, 2], id=\"wall\")\n", + "wall.set_color(np.array([1, 1, 0]))\n", "wall.set_occluder(True)\n", "scene.add_body(wall)\n", "\n", - "scene.set_camera_position_target([0,-10,4], [0,0,0])\n", + "scene.set_camera_position_target([0, -10, 4], [0, 0, 0])\n", "\n", "pyb = scene.simulate(12)\n", - "pyb.create_gif('scene_gifs/kub_occlu.gif')\n", + "pyb.create_gif(\"scene_gifs/kub_occlu.gif\")\n", "rgb = pyb.frames[0]\n", "\n", - "# preview image \n", + "# preview image\n", "prev = im.fromarray(rgb)\n", "prev" ] @@ -211,29 +213,32 @@ "metadata": {}, "outputs": [], "source": [ - "# obtain default intrinsics, hardcoded for now \n", + "# obtain default intrinsics, hardcoded for now\n", "intriniscs = b.Intrinsics(720, 960, 500.0, 500.0, 320.0, 240.0, 0.1, 100.0)\n", "\n", - "# obtain filepaths for render_many \n", - "filepath_occ = '/home/ubuntu/bayes3d/assets/sample_objs/cube.obj'\n", - "filepath_diamond = '/home/ubuntu/bayes3d/assets/sample_objs/diamond.obj'\n", + "# obtain filepaths for render_many\n", + "filepath_occ = \"/home/ubuntu/bayes3d/assets/sample_objs/cube.obj\"\n", + "filepath_diamond = \"/home/ubuntu/bayes3d/assets/sample_objs/diamond.obj\"\n", "filepaths = [filepath_occ, filepath_diamond]\n", "print(filepaths)\n", "\n", - "# create poses for render_many \n", - "poses = pyb.get_body_poses() \n", - "wall_poses = np.array(poses['wall'])\n", - "diamond_poses = np.array(poses['diamond'])\n", - "poses = np.stack((wall_poses, diamond_poses), axis = 1) \n", + "# create poses for render_many\n", + "poses = pyb.get_body_poses()\n", + "wall_poses = np.array(poses[\"wall\"])\n", + "diamond_poses = np.array(poses[\"diamond\"])\n", + "poses = np.stack((wall_poses, diamond_poses), axis=1)\n", "print(poses.shape)\n", "\n", + "\n", "def get_camera_pose(view_matrix):\n", " # cam2world\n", " world2cam = np.array(view_matrix)\n", - " cam2world = np.linalg.inv(world2cam)\n", + " cam2world = np.linalg.inv(world2cam)\n", " return cam2world\n", - "# camera pose \n", - "cam_pose = get_camera_pose(np.reshape(np.array(pyb.viewMatrix), (4,4)).T )\n", + "\n", + "\n", + "# camera pose\n", + "cam_pose = get_camera_pose(np.reshape(np.array(pyb.viewMatrix), (4, 4)).T)\n", "print(cam_pose)\n", "\n", "mesh_scaling = [wall.scale, diamond.scale]\n", @@ -246,11 +251,20 @@ "metadata": {}, "outputs": [], "source": [ - "# render simple scene with render_many \n", + "# render simple scene with render_many\n", "from bayes3d._rendering.photorealistic_renderers.kubric_interface import render_many\n", - "outputs = render_many(filepaths, poses, intriniscs, mesh_scales = mesh_scaling, mesh_colors = mesh_colors, scaling_factor=1, camera_pose = cam_pose)\n", + "\n", + "outputs = render_many(\n", + " filepaths,\n", + " poses,\n", + " intriniscs,\n", + " mesh_scales=mesh_scaling,\n", + " mesh_colors=mesh_colors,\n", + " scaling_factor=1,\n", + " camera_pose=cam_pose,\n", + ")\n", "image = b.get_rgb_image(outputs[0].rgb)\n", - "image\n" + "image" ] }, { @@ -259,13 +273,19 @@ "metadata": {}, "outputs": [], "source": [ - "# test for longer scene \n", + "# test for longer scene\n", "images = 
[b.get_rgb_image(outputs[i].rgb) for i in range(len(outputs))]\n", "output_path = \"scene_gifs/kub_output.gif\"\n", "\n", "# Save the images as a GIF\n", - "images[0].save(output_path,\n", - " save_all=True, append_images=images[1:], optimize=False, duration=1000*(1/15), loop=0)" + "images[0].save(\n", + " output_path,\n", + " save_all=True,\n", + " append_images=images[1:],\n", + " optimize=False,\n", + " duration=1000 * (1 / 15),\n", + " loop=0,\n", + ")" ] }, { @@ -287,23 +307,23 @@ "\n", "path_to_mesh = \"../assets/sample_objs/icosahedron.obj\"\n", "\n", - "# create scene and positions \n", + "# create scene and positions\n", "scene = pbs.Scene()\n", - "d20_position = np.array([1, 0, 1.5]) \n", - "rotated_position = np.array([-1, 0, 1.5]) \n", - "d20 = pbs.make_body_from_obj(path_to_mesh, d20_position, id = \"d20\")\n", + "d20_position = np.array([1, 0, 1.5])\n", + "rotated_position = np.array([-1, 0, 1.5])\n", + "d20 = pbs.make_body_from_obj(path_to_mesh, d20_position, id=\"d20\")\n", "\n", "# create a rotated d20\n", - "rotated = pbs.make_body_from_obj(path_to_mesh, rotated_position, id = \"rotated\")\n", + "rotated = pbs.make_body_from_obj(path_to_mesh, rotated_position, id=\"rotated\")\n", "\n", - "# define the rotation matrix \n", + "# define the rotation matrix\n", "sample_rotation = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])\n", "rotated.set_orientation(sample_rotation)\n", "\n", - "d20.set_color([1,0,1])\n", + "d20.set_color([1, 0, 1])\n", "scene.add_bodies([d20, rotated])\n", "image, depth, segm = scene.render(pbs.pybullet_render)\n", - "image\n" + "image" ] }, { @@ -326,16 +346,16 @@ "metadata": {}, "outputs": [], "source": [ - "# New Camera Pose \n", + "# New Camera Pose\n", "camera_pose = np.eye(4)\n", "\n", - "#camera facing downwards \n", + "# camera facing downwards\n", "camera_position = np.array([0, -10, 3])\n", - "camera_pose[:3,3] = camera_position\n", - "camera_orientation = np.array([[ -1.,0.,0.],\n", - " [0.,0.28734789,0.95782629],\n", - " [ 0.,-0.95782629,0.28734789]])\n", - "camera_pose[:3,:3] = camera_orientation\n", + "camera_pose[:3, 3] = camera_position\n", + "camera_orientation = np.array(\n", + " [[-1.0, 0.0, 0.0], [0.0, 0.28734789, 0.95782629], [0.0, -0.95782629, 0.28734789]]\n", + ")\n", + "camera_pose[:3, :3] = camera_orientation\n", "\n", "scene.set_camera_pose(camera_pose)\n", "image, depth, segm = scene.render(pbs.pybullet_render)\n", @@ -360,16 +380,16 @@ "\n", "# define scene and scene gravity, zero gravity by default\n", "scene = pbs.Scene()\n", - "scene.set_gravity([0,0,-10])\n", + "scene.set_gravity([0, 0, -10])\n", "\n", - "# create spheres \n", + "# create spheres\n", "sphere_position1 = [-1, 0, 1]\n", "sphere_start_velocity1 = [5, 0, 0]\n", "sphere_position2 = [1, 0, 1]\n", "sphere_start_velocity2 = [-5, 0, 0]\n", - "sphere1 = pbs.make_sphere(sphere_position1, 0.5,id = \"sphere1\")\n", + "sphere1 = pbs.make_sphere(sphere_position1, 0.5, id=\"sphere1\")\n", "sphere1.set_velocity(sphere_start_velocity1)\n", - "sphere1.set_color([0,1,1])\n", + "sphere1.set_color([0, 1, 1])\n", "sphere2 = pbs.make_sphere(sphere_position2, 0.5, \"sphere2\")\n", "sphere2.set_velocity(sphere_start_velocity2)\n", "\n", @@ -377,7 +397,7 @@ "scene.add_bodies([sphere1, sphere2])\n", "\n", "# simulate for 100 steps, which returns a PyBulletSim object\n", - "pyb_sim = scene.simulate(100) \n", + "pyb_sim = scene.simulate(100)\n", "\n", "# create a gif from the simulation\n", "pyb_sim.create_gif(\"scene_gifs/sphere_collision.gif\", 50)" @@ -389,7 +409,7 @@ "metadata": {}, 
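In the `Image.save` call above, `duration` is the per-frame display time in milliseconds, so `1000 * (1 / 15)` (about 66.7 ms) targets roughly 15 fps, and `loop=0` makes the GIF loop indefinitely. Parameterizing by frame rate makes the intent clearer:

fps = 15
images[0].save(
    output_path,
    save_all=True,
    append_images=images[1:],
    duration=1000 / fps,  # milliseconds per frame
    loop=0,               # 0 = loop forever
)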
"outputs": [], "source": [ - "# use pose information for other visualizers \n", + "# use pose information for other visualizers\n", "pyb_sim.get_body_poses()" ] }, @@ -421,28 +441,28 @@ "\n", "scene = pbs.Scene()\n", "\n", - "#create spheres of varying restitution\n", - "sphere1 = pbs.make_sphere(np.array([0, 0, 3]), 0.5, id = \"regular_ball\")\n", - "sphere1.set_color([1,0,0])\n", + "# create spheres of varying restitution\n", + "sphere1 = pbs.make_sphere(np.array([0, 0, 3]), 0.5, id=\"regular_ball\")\n", + "sphere1.set_color([1, 0, 0])\n", "sphere1.set_restitution(0.6)\n", - "sphere1.set_velocity([0,0,-2])\n", + "sphere1.set_velocity([0, 0, -2])\n", "\n", - "sphere2 = pbs.make_sphere(np.array([2, 0, 3]), 0.5, id = \"bouncy_ball\")\n", - "sphere2.set_color([0,1,0])\n", + "sphere2 = pbs.make_sphere(np.array([2, 0, 3]), 0.5, id=\"bouncy_ball\")\n", + "sphere2.set_color([0, 1, 0])\n", "sphere2.set_restitution(1)\n", - "sphere2.set_velocity([0,0,-2])\n", + "sphere2.set_velocity([0, 0, -2])\n", "\n", - "sphere3 = pbs.make_sphere(np.array([-2, 0, 3]), 0.5, id = \"flat_ball\")\n", - "sphere3.set_color([0,0,1])\n", + "sphere3 = pbs.make_sphere(np.array([-2, 0, 3]), 0.5, id=\"flat_ball\")\n", + "sphere3.set_color([0, 0, 1])\n", "sphere3.set_restitution(0)\n", - "sphere3.set_velocity([0,0,-2])\n", + "sphere3.set_velocity([0, 0, -2])\n", "\n", "scene.add_bodies([sphere1, sphere2, sphere3])\n", "\n", - "# Can set scene gravity \n", - "scene.set_gravity([0,0,-10])\n", + "# Can set scene gravity\n", + "scene.set_gravity([0, 0, -10])\n", "\n", - "# Can set fps of gif \n", + "# Can set fps of gif\n", "pyb_sim = scene.simulate(120)\n", "pyb_sim.create_gif(\"scene_gifs/restitution.gif\", 50)" ] @@ -464,24 +484,38 @@ "import bayes3d.utils.pybullet_sim as pbs\n", "import numpy as np\n", "\n", - "scene = pbs.Scene() \n", + "scene = pbs.Scene()\n", "path_to_d20 = \"../assets/sample_objs/icosahedron.obj\"\n", "\n", "x = -4\n", "\n", "# create spheres of varying friction\n", - "sphere = pbs.make_sphere(np.array([x,0,1]), scale = [1,1,1], id = \"regular_sphere\", friction = 1, velocity = [10,0,0])\n", - "d20 = pbs.make_body_from_obj(path_to_d20, np.array([x,1,1]), id = \"d20\", friction = 1, velocity = [10,0,0])\n", - "d20_fricitonless = pbs.make_body_from_obj(path_to_d20, np.array([x,2,1]), id = \"d20_fricitonless\", friction = 0, velocity = [10,0,0])\n", - "d20.set_color([1,0,1])\n", - "d20.set_scale = np.array([.6,.6,.6])\n", - "d20_fricitonless.set_color([0,1,1])\n", - "d20_fricitonless.set_scale = np.array([.6,.6,.6])\n", + "sphere = pbs.make_sphere(\n", + " np.array([x, 0, 1]),\n", + " scale=[1, 1, 1],\n", + " id=\"regular_sphere\",\n", + " friction=1,\n", + " velocity=[10, 0, 0],\n", + ")\n", + "d20 = pbs.make_body_from_obj(\n", + " path_to_d20, np.array([x, 1, 1]), id=\"d20\", friction=1, velocity=[10, 0, 0]\n", + ")\n", + "d20_fricitonless = pbs.make_body_from_obj(\n", + " path_to_d20,\n", + " np.array([x, 2, 1]),\n", + " id=\"d20_fricitonless\",\n", + " friction=0,\n", + " velocity=[10, 0, 0],\n", + ")\n", + "d20.set_color([1, 0, 1])\n", + "d20.set_scale = np.array([0.6, 0.6, 0.6])\n", + "d20_fricitonless.set_color([0, 1, 1])\n", + "d20_fricitonless.set_scale = np.array([0.6, 0.6, 0.6])\n", "\n", "\n", "scene.add_bodies([sphere, d20, d20_fricitonless])\n", - "scene.set_gravity([0,0,-10])\n", - "scene.set_camera_position_target([0,-10,10], [0,0,0])\n", + "scene.set_gravity([0, 0, -10])\n", + "scene.set_camera_position_target([0, -10, 10], [0, 0, 0])\n", "scene.set_downsampling(2)\n", "\n", "pyb_sim = 
scene.simulate(100)\n", @@ -514,10 +548,10 @@ "metadata": {}, "outputs": [], "source": [
- "# can track velocities of each object \n",
- "velocities = pyb_sim.get_body_velocities() \n",
+ "# can track velocities of each object\n",
+ "velocities = pyb_sim.get_body_velocities()\n",
"sphere_vel = velocities.get(\"regular_sphere\")\n",
- "linear = [timestep['velocity'] for timestep in sphere_vel]\n",
+ "linear = [timestep[\"velocity\"] for timestep in sphere_vel]\n",
"linear" ] }, @@ -561,17 +595,30 @@
"\n", "# d20 with angular velocity\n", "path_to_d20 = \"../assets/sample_objs/icosahedron.obj\"\n",
- "d20 = pbs.make_body_from_obj(path_to_d20, np.array([0,0,2]), id = \"d20\", friction = 1, velocity = [0,0,0], angular_velocity = [0,50,0])\n",
- "d20.set_color([1,0,1])\n",
- "d20.set_scale(np.array([.5,.5,.5]))\n",
+ "d20 = pbs.make_body_from_obj(\n",
+ " path_to_d20,\n",
+ " np.array([0, 0, 2]),\n",
+ " id=\"d20\",\n",
+ " friction=1,\n",
+ " velocity=[0, 0, 0],\n",
+ " angular_velocity=[0, 50, 0],\n",
+ ")\n",
+ "d20.set_color([1, 0, 1])\n",
+ "d20.set_scale(np.array([0.5, 0.5, 0.5]))\n",
"\n", "# sphere with angular velocity\n",
- "sphere = pbs.make_sphere(np.array([2,2,2]), scale = [1,1,1], id = \"regular_sphere\", friction = 1, velocity = [0,0,0])\n",
- "sphere.set_angular_velocity([50,50,0]) \n",
+ "sphere = pbs.make_sphere(\n",
+ " np.array([2, 2, 2]),\n",
+ " scale=[1, 1, 1],\n",
+ " id=\"regular_sphere\",\n",
+ " friction=1,\n",
+ " velocity=[0, 0, 0],\n",
+ ")\n",
+ "sphere.set_angular_velocity([50, 50, 0])\n",
"\n", "# low gravity\n", "scene.add_bodies([sphere, d20])\n",
- "scene.set_gravity([0,0,-5])\n",
+ "scene.set_gravity([0, 0, -5])\n",
"scene.set_downsampling(3)\n", "\n", "pyb_sim = scene.simulate(90)\n", @@ -595,29 +642,53 @@
"import bayes3d.utils.pybullet_sim as pbs\n", "import numpy as np\n", "\n",
- "# bouncing ball \n",
- "scene = pbs.Scene() \n",
- "ball = pbs.make_sphere(np.array([-3,-3,0]), scale = [1,1,1], id = \"ball\", friction = 0.1, velocity = [20,35,0])\n",
- "scene.set_camera_position_target([0,-10,10], [0,0,0])\n",
- "\n",
- "# boundaries \n",
- "wall1 = pbs.make_box(np.array([3,7,1]), scale = [10,.5,3], id = \"wall1\", friction = 0.1, velocity = [0,0,0])\n",
+ "# bouncing ball\n",
+ "scene = pbs.Scene()\n",
+ "ball = pbs.make_sphere(\n",
+ " np.array([-3, -3, 0]),\n",
+ " scale=[1, 1, 1],\n",
+ " id=\"ball\",\n",
+ " friction=0.1,\n",
+ " velocity=[20, 35, 0],\n",
+ ")\n",
+ "scene.set_camera_position_target([0, -10, 10], [0, 0, 0])\n",
+ "\n",
+ "# boundaries\n",
+ "wall1 = pbs.make_box(\n",
+ " np.array([3, 7, 1]),\n",
+ " scale=[10, 0.5, 3],\n",
+ " id=\"wall1\",\n",
+ " friction=0.1,\n",
+ " velocity=[0, 0, 0],\n",
+ ")\n",
"wall1.set_mass(0)\n",
- "wall1.set_color([1,1,0])\n",
- "\n",
- "wall2 = pbs.make_box(np.array([8,2,1]), scale = [.5,10,3], id=\"wall2\", friction = 0.1, velocity = [0,0,0])\n",
+ "wall1.set_color([1, 1, 0])\n",
+ "\n",
+ "wall2 = pbs.make_box(\n",
+ " np.array([8, 2, 1]),\n",
+ " scale=[0.5, 10, 3],\n",
+ " id=\"wall2\",\n",
+ " friction=0.1,\n",
+ " velocity=[0, 0, 0],\n",
+ ")\n",
"wall2.set_mass(0)\n",
- "wall2.set_color([1,1,0])\n",
- "\n",
- "wall3 = pbs.make_box(np.array([3,-3,1]), scale = [10,0.5,3], id=\"wall3\", friction = 0.1, velocity = [0,0,0])\n",
+ "wall2.set_color([1, 1, 0])\n",
+ "\n",
+ "wall3 = pbs.make_box(\n",
+ " np.array([3, -3, 1]),\n",
+ " scale=[10, 0.5, 3],\n",
+ " id=\"wall3\",\n",
+ " friction=0.1,\n",
+ " velocity=[0, 0, 0],\n",
+ ")\n",
"wall3.set_mass(0)\n",
- "wall3.set_color([1,1,0])\n",
+
"wall3.set_color([1, 1, 0])\n", "\n", "scene.add_bodies([ball, wall1, wall2, wall3])\n", - "scene.set_gravity([0,0,-10])\n", + "scene.set_gravity([0, 0, -10])\n", "\n", "# set pybullet to only record pose every 3rd timestep, useful for quickly rendering long simulations\n", - "scene.set_downsampling(3) \n", + "scene.set_downsampling(3)\n", "\n", "pyb_sim = scene.simulate(200)\n", "pyb_sim.create_gif(\"scene_gifs/minigolf.gif\", 30)" @@ -631,7 +702,7 @@ "source": [ "# compare to no downsampling\n", "scene.set_downsampling(1)\n", - "scene.set_timestep(1/60)\n", + "scene.set_timestep(1 / 60)\n", "pyb_sim = scene.simulate(200)\n", "pyb_sim.create_gif(\"scene_gifs/minigolf_full.gif\", 30)" ] @@ -644,7 +715,7 @@ "source": [ "# can also adjust timestep of simulation. default is 1/60 or 60hz\n", "scene.set_downsampling(1)\n", - "scene.set_timestep(1/120)\n", + "scene.set_timestep(1 / 120)\n", "pyb_sim = scene.simulate(200)\n", "pyb_sim.create_gif(\"scene_gifs/minigolf_fine.gif\", 30)" ] @@ -667,39 +738,47 @@ "import numpy as np\n", "\n", "# scale of spheres\n", - "scale = [2,2,2]\n", + "scale = [2, 2, 2]\n", "\n", - "# add spheres with random forces applied \n", + "# add spheres with random forces applied\n", "scene = pbs.Scene()\n", - "ball = pbs.make_sphere(np.array([-3,0,1]), scale = scale, id = \"ball\", friction = 0.1, velocity = [10,0,0])\n", - "ball.add_force([-1000,-1000,250], 10)\n", - "ball.add_force([1000,1000,-250], 40)\n", - "ball.add_force([-1000,-1000,-500], 70)\n", + "ball = pbs.make_sphere(\n", + " np.array([-3, 0, 1]), scale=scale, id=\"ball\", friction=0.1, velocity=[10, 0, 0]\n", + ")\n", + "ball.add_force([-1000, -1000, 250], 10)\n", + "ball.add_force([1000, 1000, -250], 40)\n", + "ball.add_force([-1000, -1000, -500], 70)\n", "scene.add_body(ball)\n", "\n", - "ball_1 = pbs.make_sphere(np.array([10,10,1]), scale = scale, id = \"ball_1\", friction = 0.1, velocity = [-10,0,0])\n", - "ball_1.add_force([1000,1000,250], 20)\n", - "ball_1.add_force([-1000,-1000,-250], 50)\n", - "ball_1.add_force([1000,1000,-500], 70)\n", - "ball_1.set_color([1,1,0])\n", + "ball_1 = pbs.make_sphere(\n", + " np.array([10, 10, 1]), scale=scale, id=\"ball_1\", friction=0.1, velocity=[-10, 0, 0]\n", + ")\n", + "ball_1.add_force([1000, 1000, 250], 20)\n", + "ball_1.add_force([-1000, -1000, -250], 50)\n", + "ball_1.add_force([1000, 1000, -500], 70)\n", + "ball_1.set_color([1, 1, 0])\n", "scene.add_body(ball_1)\n", "\n", - "ball_2 = pbs.make_sphere(np.array([0,3,1]), scale = scale, id = \"ball_2\", friction = 0.1, velocity = [0,-10,0])\n", - "ball_2.add_force([1000,1000,250], 15)\n", - "ball_2.add_force([1000,-1000,-250], 63)\n", - "ball_2.add_force([-1000,1000,-500], 77)\n", - "ball_2.set_color([1,0,1])\n", + "ball_2 = pbs.make_sphere(\n", + " np.array([0, 3, 1]), scale=scale, id=\"ball_2\", friction=0.1, velocity=[0, -10, 0]\n", + ")\n", + "ball_2.add_force([1000, 1000, 250], 15)\n", + "ball_2.add_force([1000, -1000, -250], 63)\n", + "ball_2.add_force([-1000, 1000, -500], 77)\n", + "ball_2.set_color([1, 0, 1])\n", "scene.add_body(ball_2)\n", "\n", - "ball_3 = pbs.make_sphere(np.array([0,-3,1]), scale = scale, id = \"ball_3\", friction = 0.1, velocity = [0,10,0])\n", - "ball_3.add_force([1000,1000,250], 20)\n", - "ball_3.add_force([1000,-1000,-250], 50)\n", - "ball_3.add_force([-1000,1000,-500], 70)\n", - "ball_3.set_color([0,1,1])\n", + "ball_3 = pbs.make_sphere(\n", + " np.array([0, -3, 1]), scale=scale, id=\"ball_3\", friction=0.1, velocity=[0, 10, 0]\n", + ")\n", + "ball_3.add_force([1000, 1000, 250], 
20)\n", + "ball_3.add_force([1000, -1000, -250], 50)\n", + "ball_3.add_force([-1000, 1000, -500], 70)\n", + "ball_3.set_color([0, 1, 1])\n", "scene.add_body(ball_3)\n", "\n", - "scene.set_gravity([0,0,0])\n", - "scene.set_camera_position_target([0,-40,20], [0,0,0])\n", + "scene.set_gravity([0, 0, 0])\n", + "scene.set_camera_position_target([0, -40, 20], [0, 0, 0])\n", "scene.set_downsampling(5)\n", "pyb = scene.simulate(150)\n", "pyb.create_gif(\"scene_gifs/force.gif\", 30)" @@ -718,21 +797,27 @@ "\n", "# create ten balls with random colors and random velocity changes\n", "for i in range(10):\n", - " ball = pbs.make_sphere(np.array([-3,-5+2*i,1]), scale = [1,1,1], id = f\"ball{i}\", friction = 0.1, velocity = [10,0,0])\n", + " ball = pbs.make_sphere(\n", + " np.array([-3, -5 + 2 * i, 1]),\n", + " scale=[1, 1, 1],\n", + " id=f\"ball{i}\",\n", + " friction=0.1,\n", + " velocity=[10, 0, 0],\n", + " )\n", " ball.set_color(np.random.rand(3))\n", - " \n", - " # can add velocity changes at set times or velocities, or randomly \n", - " ball.add_velocity_change([-10,0,5], 50)\n", - " ball.add_velocity_change([10,0,5], 100 + np.random.randint(-10,10))\n", - " ball.add_velocity_change(np.random.randint(-20,20,3), 150)\n", + "\n", + " # can add velocity changes at set times or velocities, or randomly\n", + " ball.add_velocity_change([-10, 0, 5], 50)\n", + " ball.add_velocity_change([10, 0, 5], 100 + np.random.randint(-10, 10))\n", + " ball.add_velocity_change(np.random.randint(-20, 20, 3), 150)\n", "\n", " # can also add forces at set times or velocities, or randomly\n", - " ball.add_force([0,0,1000], np.random.randint(0,150))\n", + " ball.add_force([0, 0, 1000], np.random.randint(0, 150))\n", " scene.add_body(ball)\n", "\n", - "# set scene gravity and downsampling for faster rendering \n", - "scene.set_gravity([0,0,-10])\n", - "scene.set_camera_position_target([0,-25,15], [0,0,0])\n", + "# set scene gravity and downsampling for faster rendering\n", + "scene.set_gravity([0, 0, -10])\n", + "scene.set_camera_position_target([0, -25, 15], [0, 0, 0])\n", "scene.downsampling = 3\n", "\n", "pyb = scene.simulate(300)\n", @@ -758,20 +843,24 @@ "\n", "scene = pbs.Scene()\n", "\n", - "# can create forces over a duration \n", - "stationary_sphere = pbs.make_sphere(np.array([-1,0,1]), scale = [1,1,1], id = \"sphere\", velocity = [0,0,0])\n", - "reverse_gravity_sphere = pbs.make_sphere(np.array([1,0,1]), scale = [1,1,1], id = \"rev_sphere\", velocity = [0,0,0])\n", - "reverse_gravity_sphere.set_color([1,0,1])\n", + "# can create forces over a duration\n", + "stationary_sphere = pbs.make_sphere(\n", + " np.array([-1, 0, 1]), scale=[1, 1, 1], id=\"sphere\", velocity=[0, 0, 0]\n", + ")\n", + "reverse_gravity_sphere = pbs.make_sphere(\n", + " np.array([1, 0, 1]), scale=[1, 1, 1], id=\"rev_sphere\", velocity=[0, 0, 0]\n", + ")\n", + "reverse_gravity_sphere.set_color([1, 0, 1])\n", "\n", - "# add force from 0 to 40th timestep \n", - "reverse_gravity_sphere.add_force([0,0,10], 0, end_timestep = 40)\n", + "# add force from 0 to 40th timestep\n", + "reverse_gravity_sphere.add_force([0, 0, 10], 0, end_timestep=40)\n", "\n", - "# simulate \n", + "# simulate\n", "scene.add_bodies([stationary_sphere, reverse_gravity_sphere])\n", "scene.downsampling = 3\n", - "scene.set_camera_position_target([0,-10,15], [0,0,0])\n", + "scene.set_camera_position_target([0, -10, 15], [0, 0, 0])\n", "pyb = scene.simulate(100)\n", - "pyb.create_gif(\"scene_gifs/force_duration.gif\", 30) \n" + "pyb.create_gif(\"scene_gifs/force_duration.gif\", 30)" 
] }, { @@ -786,17 +875,21 @@
"scene = pbs.Scene()\n", "\n", "# can create forces every nth timestep\n",
- "stationary_sphere = pbs.make_sphere(np.array([-1,0,1]), scale = [1,1,1], id = \"sphere\", velocity = [0,0,0])\n",
- "nudged_sphere = pbs.make_sphere(np.array([1,0,1]), scale = [1,1,1], id = \"n_sphere\", velocity = [0,0,0])\n",
- "nudged_sphere.set_color([1,0,1])\n",
+ "stationary_sphere = pbs.make_sphere(\n",
+ " np.array([-1, 0, 1]), scale=[1, 1, 1], id=\"sphere\", velocity=[0, 0, 0]\n",
+ ")\n",
+ "nudged_sphere = pbs.make_sphere(\n",
+ " np.array([1, 0, 1]), scale=[1, 1, 1], id=\"n_sphere\", velocity=[0, 0, 0]\n",
+ ")\n",
+ "nudged_sphere.set_color([1, 0, 1])\n",
"\n", "# add force every 60th timestep\n",
- "nudged_sphere.add_force([0,0,400], 0, 300, step = 60)\n",
+ "nudged_sphere.add_force([0, 0, 400], 0, 300, step=60)\n",
"\n", "# simulate\n", "scene.add_bodies([stationary_sphere, nudged_sphere])\n",
- "scene.set_gravity([0,0,-10])\n",
- "scene.set_camera_position_target([0,-10,15], [0,0,0])\n",
+ "scene.set_gravity([0, 0, -10])\n",
+ "scene.set_camera_position_target([0, -10, 15], [0, 0, 0])\n",
"scene.downsampling = 3\n", "pyb = scene.simulate(400)\n", "pyb.create_gif(\"scene_gifs/force_step.gif\", 30)" @@ -813,17 +906,19 @@
"\n", "scene = pbs.Scene()\n", "\n",
- "for i in range(20): \n",
- " position = np.random.randint(-5,5,3) + np.array([0,0,6])\n",
- " fairy = pbs.make_sphere(position, scale = [1,1,1], id = f\"fairy{i}\", velocity = [0,0,0])\n",
+ "for i in range(20):\n",
+ " position = np.random.randint(-5, 5, 3) + np.array([0, 0, 6])\n",
+ " fairy = pbs.make_sphere(\n",
+ " position, scale=[1, 1, 1], id=f\"fairy{i}\", velocity=[0, 0, 0]\n",
+ " )\n",
" fairy.set_color(np.random.rand(3))\n",
- " for j in range(0,200,10):\n",
- " vel_change = np.random.randint(0,2,3) * 10 - 5\n",
+ " for j in range(0, 200, 10):\n",
+ " vel_change = np.random.randint(0, 2, 3) * 10 - 5\n",
" fairy.add_velocity_change(vel_change, j)\n", " scene.add_body(fairy)\n", "\n",
- "scene.set_gravity([0,0,0])\n",
- "scene.set_camera_position_target([0,-20,15], [0,0,0])\n",
+ "scene.set_gravity([0, 0, 0])\n",
+ "scene.set_camera_position_target([0, -20, 15], [0, 0, 0])\n",
"scene.downsampling = 3\n", "pyb = scene.simulate(300)\n", "pyb.create_gif(\"scene_gifs/brownian.gif\", 30)" @@ -857,21 +952,33 @@
"\n", "scene = pbs.Scene()\n", "\n",
- "# add a ton of balls \n",
- "for i in range(10): \n",
- " ball = pbs.make_sphere(np.array([-10,-5+i*(1.5),1]), scale = [1,1,1], id = f\"ball{i}\", friction = 0.1, velocity = [8,0,0])\n",
- " ball.add_force([0,0,600], 60 + i*5)\n",
+ "# add a ton of balls\n",
+ "for i in range(10):\n",
+ " ball = pbs.make_sphere(\n",
+ " np.array([-10, -5 + i * (1.5), 1]),\n",
+ " scale=[1, 1, 1],\n",
+ " id=f\"ball{i}\",\n",
+ " friction=0.1,\n",
+ " velocity=[8, 0, 0],\n",
+ " )\n",
+ " ball.add_force([0, 0, 600], 60 + i * 5)\n",
" ball.set_color(np.random.rand(3))\n", " scene.add_body(ball)\n", "\n", "\n",
- "wall = pbs.make_box(np.array([10,0,2.5]), scale = [1,20,5], id = \"wall\", friction = 0.1, velocity = [0,0,0])\n",
+ "wall = pbs.make_box(\n",
+ " np.array([10, 0, 2.5]),\n",
+ " scale=[1, 20, 5],\n",
+ " id=\"wall\",\n",
+ " friction=0.1,\n",
+ " velocity=[0, 0, 0],\n",
+ ")\n",
"wall.set_mass(500)\n",
- "wall.set_color([1,.5,1])\n",
+ "wall.set_color([1, 0.5, 1])\n",
"scene.add_body(wall)\n", "\n",
- "scene.set_gravity([0,0,-10])\n",
- "scene.set_camera_position_target([-3,-20,15], [0,0,0])\n",
+ "scene.set_gravity([0, 0, -10])\n",
+
"scene.set_camera_position_target([-3, -20, 15], [0, 0, 0])\n", "scene.set_downsampling(4)\n", "\n", "pyb = scene.simulate(300)\n", @@ -893,17 +1000,19 @@ "source": [ "import matplotlib.pyplot as plt\n", "\n", - "def extract_linear_vel(pyb): \n", - " velocities = pyb.get_body_velocities() \n", + "\n", + "def extract_linear_vel(pyb):\n", + " velocities = pyb.get_body_velocities()\n", " linear_vel = {}\n", " for body in velocities.keys():\n", " body_vel = velocities[body]\n", - " raw = [] \n", - " for i in range(300//4): \n", - " raw.append(body_vel[i]['velocity'])\n", + " raw = []\n", + " for i in range(300 // 4):\n", + " raw.append(body_vel[i][\"velocity\"])\n", " linear_vel[body] = raw\n", " return linear_vel\n", "\n", + "\n", "linear_vel = extract_linear_vel(pyb)" ] }, @@ -920,14 +1029,16 @@ " for obj_name, velocities in velocity_dict.items():\n", " # Assume velocities are tuples/lists in the form (x, y, z)\n", " velocities = np.array(velocities) # Convert to numpy array for convenience\n", - " velocity_magnitudes = np.linalg.norm(velocities, axis=1) # Calculate the norm (magnitude) of the velocities\n", + " velocity_magnitudes = np.linalg.norm(\n", + " velocities, axis=1\n", + " ) # Calculate the norm (magnitude) of the velocities\n", " plt.plot(time, velocity_magnitudes, label=obj_name)\n", "\n", - " plt.xlabel('Time')\n", - " plt.ylabel('Velocity')\n", - " plt.title('Velocity Magnitude over Time')\n", + " plt.xlabel(\"Time\")\n", + " plt.ylabel(\"Velocity\")\n", + " plt.title(\"Velocity Magnitude over Time\")\n", " plt.legend()\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -945,30 +1056,33 @@ "metadata": {}, "outputs": [], "source": [ - "import copy \n", + "import copy\n", + "\n", "linear_vel_mag = copy.deepcopy(linear_vel)\n", - "for body in linear_vel.keys(): \n", - " linear_vel_mag[body] = [] \n", - " for time in range(len(linear_vel[body])): \n", + "for body in linear_vel.keys():\n", + " linear_vel_mag[body] = []\n", + " for time in range(len(linear_vel[body])):\n", " linear_vel_mag[body].append(np.linalg.norm(linear_vel[body][time]))\n", "\n", "import numpy as np\n", "\n", - "def convert_seg_to_vel(pyb, velocities): \n", - " segm = pyb.segm \n", - " vels = segm.copy() \n", - " for time, frame in enumerate(vels): \n", - " for idx, pyb_id in np.ndenumerate(frame): \n", - " # map pyb_id to body_id to velocity \n", - " if pyb_id not in {0,-1}:\n", + "\n", + "def convert_seg_to_vel(pyb, velocities):\n", + " segm = pyb.segm\n", + " vels = segm.copy()\n", + " for time, frame in enumerate(vels):\n", + " for idx, pyb_id in np.ndenumerate(frame):\n", + " # map pyb_id to body_id to velocity\n", + " if pyb_id not in {0, -1}:\n", " body_id = pyb.pyb_id_to_body_id[pyb_id]\n", " magnitude = velocities[body_id][time]\n", " frame[idx] = magnitude # Update the velocity for the specific id\n", - " else: \n", - " frame[idx] = 0 \n", + " else:\n", + " frame[idx] = 0\n", "\n", " return vels\n", "\n", + "\n", "vels = convert_seg_to_vel(pyb, linear_vel_mag)" ] }, @@ -982,39 +1096,40 @@ "import numpy as np\n", "import imageio\n", "\n", + "\n", "def create_heatmap_gif(data, filename):\n", " # Create a list to store each frame of the GIF\n", " frames = []\n", - " \n", + "\n", " # Define the colormap\n", - " cmap = plt.get_cmap('viridis')\n", + " cmap = plt.get_cmap(\"viridis\")\n", "\n", " # Loop over each array in the data\n", " for array in data:\n", " # Create a figure and axes\n", " fig, ax = plt.subplots()\n", - " \n", + "\n", " # Create the heatmap for the current array\n", " heatmap = 
ax.imshow(array, cmap=cmap)\n",
- " \n",
+ "\n",
" # Remove the axis\n",
- " ax.axis('off')\n",
+ " ax.axis(\"off\")\n",
"\n", " # Draw the figure and retrieve the pixel data\n", " fig.canvas.draw()\n",
- " image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
+ " image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=\"uint8\")\n",
"\n", " # Reshape the image to the correct dimensions\n", " image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n", "\n", " # Append the image to the frames list\n", " frames.append(image)\n",
- " \n",
+ "\n",
" # Close the figure to save memory\n", " plt.close(fig)\n", "\n", " # Save the frames as a GIF\n",
- " imageio.mimsave(filename, frames, 'GIF', duration = 0.1)\n"
+ " imageio.mimsave(filename, frames, \"GIF\", duration=0.1)"
] }, { @@ -1023,7 +1138,7 @@ "metadata": {}, "outputs": [], "source": [
- "create_heatmap_gif(vels, 'scene_gifs/heatmap.gif')"
+ "create_heatmap_gif(vels, \"scene_gifs/heatmap.gif\")"
] }, { @@ -1051,23 +1166,23 @@
"import bayes3d.utils.pybullet_sim as pbs\n", "import numpy as np\n", "\n",
- "# scene of an object partially occluded by a wall \n",
- "scene = pbs.Scene() \n",
+ "# scene of an object partially occluded by a wall\n",
+ "scene = pbs.Scene()\n",
"scene.set_downsampling(3)\n", "\n",
- "path_to_obj = \"../assets/sample_objs/diamond.obj\" \n",
- "position = np.array([-4,6,2])\n",
+ "path_to_obj = \"../assets/sample_objs/diamond.obj\"\n",
+ "position = np.array([-4, 6, 2])\n",
"diamond = pbs.make_body_from_obj(path_to_obj, position, id=\"diamond\")\n",
- "diamond.set_color(np.array([0,1,0]))\n",
- "diamond.set_scale(np.array([3,3,3]))\n",
+ "diamond.set_color(np.array([0, 1, 0]))\n",
+ "diamond.set_scale(np.array([3, 3, 3]))\n",
"scene.add_body(diamond)\n", "\n",
- "wall = pbs.make_box([0,0,2], [4,1,2], id = \"wall\")\n",
- "wall.set_color(np.array([1,1,0]))\n",
+ "wall = pbs.make_box([0, 0, 2], [4, 1, 2], id=\"wall\")\n",
+ "wall.set_color(np.array([1, 1, 0]))\n",
"wall.set_occluder(True)\n", "scene.add_body(wall)\n", "\n",
- "scene.set_camera_position_target([0,-10,4], [0,0,0])\n",
+ "scene.set_camera_position_target([0, -10, 4], [0, 0, 0])\n",
"\n", "pyb = scene.simulate(12)" ] @@ -1086,13 +1201,13 @@ "metadata": {}, "outputs": [], "source": [
- "# multirender \n",
- "# check if object is occluder \n",
- "# if it isn't, add to mesh \n",
+ "# multirender\n",
+ "# check if object is occluder\n",
+ "# if it isn't, add to mesh\n",
"\n",
- "poses = pyb.get_body_poses() \n",
+ "poses = pyb.get_body_poses()\n",
"\n",
- "for body_id in scene.bodies: \n",
+ "for body_id in scene.bodies:\n",
" body = scene.bodies[body_id]\n", " print(body.file_dir)\n", " print(len(poses[body_id]))" diff --git a/scripts/experiments/mcs/otp_gen/otp_gen/tracking.ipynb b/scripts/experiments/mcs/otp_gen/otp_gen/tracking.ipynb index 49a4a51f..c59509aa 100644 --- a/scripts/experiments/mcs/otp_gen/otp_gen/tracking.ipynb +++ b/scripts/experiments/mcs/otp_gen/otp_gen/tracking.ipynb @@ -36,43 +36,53 @@
"import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "from IPython.display import HTML\n",
+ "\n",
"print(os.getcwd(), \"this is the current working directory\")\n", "\n",
+ "\n",
"def display_video(frames, framerate=30):\n", " height, width, _ = frames[0].shape\n", " dpi = 70\n", " orig_backend = matplotlib.get_backend()\n",
- " matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering.\n",
+ " matplotlib.use(\"Agg\") # Switch to headless 'Agg' to inhibit figure rendering.\n",
" fig, ax 
= plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)\n", " matplotlib.use(orig_backend) # Switch back to the original backend.\n", " ax.set_axis_off()\n", - " ax.set_aspect('equal')\n", + " ax.set_aspect(\"equal\")\n", " ax.set_position([0, 0, 1, 1])\n", " im = ax.imshow(frames[0])\n", + "\n", " def update(frame):\n", - " im.set_data(frame)\n", - " return [im]\n", - " interval = 1000/framerate\n", - " anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n", - " interval=interval, blit=True, repeat=True)\n", + " im.set_data(frame)\n", + " return [im]\n", + "\n", + " interval = 1000 / framerate\n", + " anim = animation.FuncAnimation(\n", + " fig=fig, func=update, frames=frames, interval=interval, blit=True, repeat=True\n", + " )\n", " return HTML(anim.to_html5_video())\n", "\n", + "\n", "def object_pose_in_camera_frame(object_id, view_matrix):\n", - " object_pos, object_orn = p.getBasePositionAndOrientation(object_id) # world frame\n", - " world2cam = np.array(view_matrix).reshape([4,4]).T # world --> cam \n", + " object_pos, object_orn = p.getBasePositionAndOrientation(object_id) # world frame\n", + " world2cam = np.array(view_matrix).reshape([4, 4]).T # world --> cam\n", " object_transform_matrix = np.eye(4)\n", - " object_transform_matrix[:3, :3] = np.reshape(p.getMatrixFromQuaternion(object_orn), (3, 3))\n", + " object_transform_matrix[:3, :3] = np.reshape(\n", + " p.getMatrixFromQuaternion(object_orn), (3, 3)\n", + " )\n", " object_transform_matrix[:3, 3] = object_pos\n", " return world2cam @ object_transform_matrix\n", "\n", + "\n", "def get_camera_pose(view_matrix):\n", " # cam2world\n", " world2cam = np.array(view_matrix)\n", - " cam2world = np.linalg.inv(world2cam)\n", + " cam2world = np.linalg.inv(world2cam)\n", " return cam2world\n", "\n", + "\n", "def object_pose_in_camera_frame(object_pose, view_matrix):\n", - " world2cam = np.array(view_matrix).reshape([4,4]) # world --> cam \n", + " world2cam = np.array(view_matrix).reshape([4, 4]) # world --> cam\n", " return world2cam @ object_pose" ] }, @@ -89,34 +99,42 @@ "metadata": {}, "outputs": [], "source": [ - "import bayes3d.utils.pybullet_sim as pyb \n", + "import bayes3d.utils.pybullet_sim as pyb\n", "\n", "scene = pyb.Scene()\n", "scene.set_gravity([0, 0, -9.8])\n", - "scene.set_timestep(1/240)\n", - "scene.set_downsampling(4) \n", + "scene.set_timestep(1 / 240)\n", + "scene.set_downsampling(4)\n", "\n", - "occ1_meshscale = [0.07,0.07,0.07]\n", + "occ1_meshscale = [0.07, 0.07, 0.07]\n", "path_to_obj = \"../assets/sample_objs/plane.obj\"\n", - "base_position = [1,-1,1]\n", + "base_position = [1, -1, 1]\n", "base_orientation = [0.7071068, 0, 0, 0.7071068]\n", - "base_orientation = np.array(p.getMatrixFromQuaternion(base_orientation)).reshape(3,3)\n", - "occ1 = pyb.make_body_from_obj(path_to_obj, base_position, orientation=base_orientation,scale = occ1_meshscale, id = \"occluder\")\n", + "base_orientation = np.array(p.getMatrixFromQuaternion(base_orientation)).reshape(3, 3)\n", + "occ1 = pyb.make_body_from_obj(\n", + " path_to_obj,\n", + " base_position,\n", + " orientation=base_orientation,\n", + " scale=occ1_meshscale,\n", + " id=\"occluder\",\n", + ")\n", "occ1.set_mass(0)\n", - "occ1.set_color([0.5,0.5,0.5])\n", + "occ1.set_color([0.5, 0.5, 0.5])\n", "scene.add_body(occ1)\n", "\n", "box_mass = 1\n", "path_to_box = \"../assets/sample_objs/cube.obj\"\n", "box_position = [-3.25, 0, 0.501]\n", "box_start_velocity = [6, 0, 6]\n", - "mesh_scale = [0.5,0.5,0.5]\n", - "box = 
pyb.make_body_from_obj(path_to_box, box_position, scale = mesh_scale, restitution=1, id = \"d20\")\n", + "mesh_scale = [0.5, 0.5, 0.5]\n", + "box = pyb.make_body_from_obj(\n", + " path_to_box, box_position, scale=mesh_scale, restitution=1, id=\"d20\"\n", + ")\n", "box.set_velocity(box_start_velocity)\n", "scene.add_body(box)\n", "\n", - "bull = scene.simulate(360, defaultView = True)\n", - "bull.create_gif(\"tracking.gif\")\n" + "bull = scene.simulate(360, defaultView=True)\n", + "bull.create_gif(\"tracking.gif\")" ] }, { @@ -130,13 +148,13 @@ "plane_ori = p.getMatrixFromQuaternion(plane_ori)\n", "plane_cam_pose = np.eye(4)\n", "plane_cam_pose[:3, 3] = plane_pos\n", - "plane_cam_pose[:3, :3] = np.array(plane_ori).reshape(3,3)\n", - "view_matrix = bull.viewMatrix \n", - "view_matrix = np.array(view_matrix).reshape([4,4]).T\n", - "box_poses = [view_matrix@pose for pose in box_poses]\n", - "plane_cam_poses = view_matrix@plane_cam_pose\n", - "occ1_pose = view_matrix@ bull.get_body_poses()['occluder'][0]\n", - "cam_pose = get_camera_pose(view_matrix) " + "plane_cam_pose[:3, :3] = np.array(plane_ori).reshape(3, 3)\n", + "view_matrix = bull.viewMatrix\n", + "view_matrix = np.array(view_matrix).reshape([4, 4]).T\n", + "box_poses = [view_matrix @ pose for pose in box_poses]\n", + "plane_cam_poses = view_matrix @ plane_cam_pose\n", + "occ1_pose = view_matrix @ bull.get_body_poses()[\"occluder\"][0]\n", + "cam_pose = get_camera_pose(view_matrix)" ] }, { @@ -153,14 +171,14 @@ "outputs": [], "source": [ "array_dict = {\n", - " 'box': box_poses, # box poses in camera view \n", - " 'plane': plane_cam_pose, # plane pose in camera view \n", - " 'occ1' : occ1_pose, # occluder pose in camera view\n", - " 'occ1_meshscale' : occ1_meshscale,\n", - " 'cam_pose' : cam_pose, # camera pose in world view\n", + " \"box\": box_poses, # box poses in camera view\n", + " \"plane\": plane_cam_pose, # plane pose in camera view\n", + " \"occ1\": occ1_pose, # occluder pose in camera view\n", + " \"occ1_meshscale\": occ1_meshscale,\n", + " \"cam_pose\": cam_pose, # camera pose in world view\n", "}\n", "\n", - "np.savez('scene_npzs/default_view_scene_demo.npz', **array_dict)" + "np.savez(\"scene_npzs/default_view_scene_demo.npz\", **array_dict)" ] }, { @@ -184,7 +202,7 @@ } ], "source": [ - "# sanity checks \n", + "# sanity checks\n", "\n", "print(occ1_pose)\n", "print(cam_pose)" @@ -205,13 +223,14 @@ "source": [ "import bayes3d as b\n", "import bayes3d.transforms_3d as t3d\n", - "import numpy as np \n", + "import numpy as np\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import jax\n", "from jax.debug import print as jprint\n", "import physics_priors as p\n", "import importlib\n", + "\n", "importlib.reload(p)\n", "import time\n", "import PIL.Image\n", @@ -220,6 +239,7 @@ "import matplotlib.pyplot as plt\n", "from IPython.display import HTML\n", "import plotly.graph_objects as go\n", + "\n", "%matplotlib inline " ] }, @@ -247,9 +267,12 @@ "intrinsics = b.Intrinsics(\n", " height=360,\n", " width=480,\n", - " fx=180*sqrt(3), fy=180*sqrt(3),\n", - " cx=240.0, cy=180.0,\n", - " near=0.1, far=10.0\n", + " fx=180 * sqrt(3),\n", + " fy=180 * sqrt(3),\n", + " cx=240.0,\n", + " cy=180.0,\n", + " near=0.1,\n", + " far=10.0,\n", ")\n", "b.setup_renderer(intrinsics)" ] @@ -260,16 +283,16 @@ "metadata": {}, "outputs": [], "source": [ - "loaded_poses = np.load('scene_npzs/velchange_scene.npz')\n", + "loaded_poses = np.load(\"scene_npzs/velchange_scene.npz\")\n", "# Find number of timesteps\n", - "N_tsteps = 
loaded_poses['box'].shape[0]\n", + "N_tsteps = loaded_poses[\"box\"].shape[0]\n", "\n", "# load occluder poses into (N,4,4)\n", - "occ1_meshscale = loaded_poses['occ1_meshscale']\n", + "occ1_meshscale = loaded_poses[\"occ1_meshscale\"]\n", "# occ1_meshscale = [0.0667,0.0667,0.0667]\n", - "occ1_pose = loaded_poses['occ1']\n", + "occ1_pose = loaded_poses[\"occ1\"]\n", "occ1_pose[1:3] *= -1\n", - "occ1_poses = jnp.tile(jnp.array(occ1_pose), (N_tsteps,1,1))\n", + "occ1_poses = jnp.tile(jnp.array(occ1_pose), (N_tsteps, 1, 1))\n", "\n", "# occ2_meshscale = loaded_poses['occ2_meshscale']\n", "# occ2_pose = loaded_poses['occ2']\n", @@ -277,21 +300,27 @@ "# occ2_poses = jnp.tile(jnp.array(occ2_pose), (N_tsteps,1,1))\n", "\n", "# get object poses\n", - "gt_poses = loaded_poses['box']\n", - "gt_poses[:,1:3,:] *= -1 # CV2 convention\n", + "gt_poses = loaded_poses[\"box\"]\n", + "gt_poses[:, 1:3, :] *= -1 # CV2 convention\n", "gt_poses = jnp.array(gt_poses)\n", "\n", "# combine the poses\n", "# total_gt_poses = jnp.stack([gt_poses, occ1_poses, occ2_poses], axis = 1)\n", - "total_gt_poses = jnp.stack([gt_poses, occ1_poses], axis = 1)\n", + "total_gt_poses = jnp.stack([gt_poses, occ1_poses], axis=1)\n", "\n", "# cv2 convention of cam pose\n", - "world2cam = np.linalg.inv(loaded_poses['cam_pose'])\n", + "world2cam = np.linalg.inv(loaded_poses[\"cam_pose\"])\n", "world2cam[1:3] *= -1\n", - "cam_pose = jnp.linalg.inv(jnp.array(world2cam)) \n", + "cam_pose = jnp.linalg.inv(jnp.array(world2cam))\n", "\n", - "b.RENDERER.add_mesh_from_file(\"../assets/sample_objs/cube.obj\",mesh_name=\"cube_0\", scaling_factor=[0.5,0.5,0.5])\n", - "b.RENDERER.add_mesh_from_file(\"../assets/sample_objs/plane.obj\",mesh_name=\"occluder1\", scaling_factor=occ1_meshscale)" + "b.RENDERER.add_mesh_from_file(\n", + " \"../assets/sample_objs/cube.obj\", mesh_name=\"cube_0\", scaling_factor=[0.5, 0.5, 0.5]\n", + ")\n", + "b.RENDERER.add_mesh_from_file(\n", + " \"../assets/sample_objs/plane.obj\",\n", + " mesh_name=\"occluder1\",\n", + " scaling_factor=occ1_meshscale,\n", + ")" ] }, { @@ -305,10 +334,12 @@ "\n", "# gt_images = b.RENDERER.render_multiobject_parallel(total_gt_poses, [0,1,2])\n", "\n", - "gt_images = b.RENDERER.render_many(total_gt_poses,jnp.array([0,1]))\n", + "gt_images = b.RENDERER.render_many(total_gt_poses, jnp.array([0, 1]))\n", "\n", "# gt_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(gt_poses[:,None, ...], jnp.array([0]))\n", - "depths = [b.viz.get_depth_image(gt_images[i,:,:,2]) for i in range(gt_images.shape[0])]" + "depths = [\n", + " b.viz.get_depth_image(gt_images[i, :, :, 2]) for i in range(gt_images.shape[0])\n", + "]" ] }, { @@ -318,26 +349,26 @@ "outputs": [], "source": [ "dx, dy, dz = 0.2, 0.2, 0.2\n", - "translation_deltas = b.utils.make_translation_grid_enumeration(-dx, -dy, -dz, dx, dy, dz, 5,5,5)\n", + "translation_deltas = b.utils.make_translation_grid_enumeration(\n", + " -dx, -dy, -dz, dx, dy, dz, 5, 5, 5\n", + ")\n", "\n", "\n", "dx, dy, dz = 2, 2, 2\n", "gridding = [\n", + " b.utils.make_translation_grid_enumeration(-dx, -dy, -dz, dx, dy, dz, 5, 5, 5),\n", " b.utils.make_translation_grid_enumeration(\n", - " -dx, -dy, -dz, dx, dy, dz, 5,5,5\n", + " -dx / 2.0, -dy / 2, -dz / 2, dx / 2, dy / 2, dz / 2, 5, 5, 5\n", " ),\n", " b.utils.make_translation_grid_enumeration(\n", - " -dx/2.0, -dy/2, -dz/2, dx/2, dy/2, dz/2, 5,5,5\n", - " ),\n", - " b.utils.make_translation_grid_enumeration(\n", - " -dx/10.0, -dy/10, -dz/10, dx/10, dy/10, dz/10, 5,5,5\n", + " -dx / 10.0, -dy / 10, -dz / 10, dx / 10, 
dy / 10, dz / 10, 5, 5, 5\n", " ),\n", "]\n", "\n", "key = jax.random.PRNGKey(314)\n", - "rotation_deltas = jax.vmap(lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0))(\n", - " jax.random.split(key, 100)\n", - ")\n", + "rotation_deltas = jax.vmap(\n", + " lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0)\n", + ")(jax.random.split(key, 100))\n", "\n", "# len_proposals = gridding[0].shape[0]\n", "len_proposals = translation_deltas.shape[0]\n", @@ -356,36 +387,48 @@ "outputs": [], "source": [ "def update_pose_estimate(memory, gt_image):\n", - "\n", " pose_memory, T = memory\n", - " old_pose_estimate = pose_memory[T,...]\n", - " prev_pose = pose_memory[T-1,...]\n", + " old_pose_estimate = pose_memory[T, ...]\n", + " prev_pose = pose_memory[T - 1, ...]\n", "\n", " threedp3_weight = 1\n", " proposals = jnp.einsum(\"ij,ajk->aik\", old_pose_estimate, translation_deltas)\n", - " rendered_images = b.RENDERER.render_many(jnp.stack([proposals, occ1_poses_trans],axis = 1), jnp.array([0,1]))\n", - " threedp3_scores = threedp3_weight * b.threedp3_likelihood_parallel(gt_image, rendered_images, 0.001, 0.1, 10**3, 3)\n", - " unique_best_3dp3_score = jnp.sum(threedp3_scores == threedp3_scores[jnp.argmax(threedp3_scores)]) == 1\n", + " rendered_images = b.RENDERER.render_many(\n", + " jnp.stack([proposals, occ1_poses_trans], axis=1), jnp.array([0, 1])\n", + " )\n", + " threedp3_scores = threedp3_weight * b.threedp3_likelihood_parallel(\n", + " gt_image, rendered_images, 0.001, 0.1, 10**3, 3\n", + " )\n", + " unique_best_3dp3_score = (\n", + " jnp.sum(threedp3_scores == threedp3_scores[jnp.argmax(threedp3_scores)]) == 1\n", + " )\n", "\n", - " physics_weight = jax.lax.cond(unique_best_3dp3_score, lambda _ : 5000, lambda _ : 10000, None)\n", + " physics_weight = jax.lax.cond(\n", + " unique_best_3dp3_score, lambda _: 5000, lambda _: 10000, None\n", + " )\n", " # physics_weight = 0\n", "\n", - " physics_estimated_pose = p.physics_prior_v1_jit(old_pose_estimate, prev_pose, jnp.array([1,1,1]), cam_pose, world2cam)\n", + " physics_estimated_pose = p.physics_prior_v1_jit(\n", + " old_pose_estimate, prev_pose, jnp.array([1, 1, 1]), cam_pose, world2cam\n", + " )\n", "\n", - " physics_scores = jax.lax.cond(jnp.greater(T, 1), \n", - " lambda _ : physics_weight * p.physics_prior_parallel_jit(proposals, physics_estimated_pose), \n", - " lambda _ : jnp.zeros(threedp3_scores.shape[0]), \n", - " None)\n", + " physics_scores = jax.lax.cond(\n", + " jnp.greater(T, 1),\n", + " lambda _: physics_weight\n", + " * p.physics_prior_parallel_jit(proposals, physics_estimated_pose),\n", + " lambda _: jnp.zeros(threedp3_scores.shape[0]),\n", + " None,\n", + " )\n", "\n", " scores = threedp3_scores + physics_scores\n", "\n", " pose_estimate = proposals[jnp.argmax(scores)]\n", - " pose_memory = pose_memory.at[T+1,...].set(pose_estimate)\n", + " pose_memory = pose_memory.at[T + 1, ...].set(pose_estimate)\n", "\n", " pose_world = cam_pose @ pose_estimate\n", - " gt_pose_world = cam_pose @ gt_poses[T-1]\n", - " jprint(\"{}: {}, {}\", T, unique_best_3dp3_score, pose_world[:3,3])\n", - " jprint(\"{}: GT, {}\\n\", T, gt_pose_world[:3,3])\n", + " gt_pose_world = cam_pose @ gt_poses[T - 1]\n", + " jprint(\"{}: {}, {}\", T, unique_best_3dp3_score, pose_world[:3, 3])\n", + " jprint(\"{}: GT, {}\\n\", T, gt_pose_world[:3, 3])\n", "\n", " # proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, rotation_deltas)\n", " # rendered_images = b.RENDERER.render_multiobject_parallel(jnp.stack([proposals, 
occ_poses_rot]), jnp.array([0,1]))\n", @@ -393,7 +436,7 @@ " # weights_new = b.threedp3_likelihood_parallel(gt_image, rendered_images, 0.05, 0.1, 10**3, 3)\n", " # pose_estimate = proposals[jnp.argmax(weights_new)]\n", "\n", - " return (pose_memory, T+1), pose_estimate" + " return (pose_memory, T + 1), pose_estimate" ] }, { @@ -682,16 +725,17 @@ ], "source": [ "importlib.reload(p)\n", - "inference_program = jax.jit(lambda p,x: jax.lax.scan(\n", - " update_pose_estimate, \n", - " (jnp.tile(p, (x.shape[0]+1,1,1)),1),\n", - " x)[1])\n", + "inference_program = jax.jit(\n", + " lambda p, x: jax.lax.scan(\n", + " update_pose_estimate, (jnp.tile(p, (x.shape[0] + 1, 1, 1)), 1), x\n", + " )[1]\n", + ")\n", "\n", "start = time.time()\n", "inferred_poses = inference_program(gt_poses[0], gt_images)\n", "end = time.time()\n", - "print (\"Time elapsed:\", end - start)\n", - "print (\"FPS:\", gt_poses.shape[0] / (end - start))" + "print(\"Time elapsed:\", end - start)\n", + "print(\"FPS:\", gt_poses.shape[0] / (end - start))" ] }, { @@ -704,21 +748,34 @@ "max_depth = 10.0\n", "\n", "# inferred_poses_with_occ = jnp.stack([inferred_poses, occ1_poses], axis = 1)\n", - "occ_image = b.viz.get_depth_image(b.RENDERER.render(occ1_pose[None,...], jnp.array([1]))[:,:,2])\n", + "occ_image = b.viz.get_depth_image(\n", + " b.RENDERER.render(occ1_pose[None, ...], jnp.array([1]))[:, :, 2]\n", + ")\n", "\n", - "pred_images = b.RENDERER.render_many(inferred_poses[:,None, ...], jnp.array([0]))\n", + "pred_images = b.RENDERER.render_many(inferred_poses[:, None, ...], jnp.array([0]))\n", "\n", - "pred_with_occ_images = [b.overlay_image(b.viz.get_depth_image(pred_images[i,:,:,2]), \n", - "occ_image, alpha=0.4) for i in range(pred_images.shape[0])]\n", + "pred_with_occ_images = [\n", + " b.overlay_image(\n", + " b.viz.get_depth_image(pred_images[i, :, :, 2]), occ_image, alpha=0.4\n", + " )\n", + " for i in range(pred_images.shape[0])\n", + "]\n", "\n", - "gt_images = b.RENDERER.render_many(gt_poses[:,None, ...], jnp.array([0]))\n", - "gt_with_occ_images = [b.overlay_image(b.viz.get_depth_image(gt_images[i,:,:,2]), occ_image, alpha=0.5) for i in range(pred_images.shape[0])]\n", + "gt_images = b.RENDERER.render_many(gt_poses[:, None, ...], jnp.array([0]))\n", + "gt_with_occ_images = [\n", + " b.overlay_image(b.viz.get_depth_image(gt_images[i, :, :, 2]), occ_image, alpha=0.5)\n", + " for i in range(pred_images.shape[0])\n", + "]\n", "\n", "viz_images = [\n", " b.viz.multi_panel(\n", - " [g, b.viz.get_depth_image(p[:,:,2]), po],\n", - " labels = [\"Ground Truth\", \"Reconstruction w/o Occluder\", \"Reconstruction w Occluder\"],\n", - " title = \"External Force Demo\",\n", + " [g, b.viz.get_depth_image(p[:, :, 2]), po],\n", + " labels=[\n", + " \"Ground Truth\",\n", + " \"Reconstruction w/o Occluder\",\n", + " \"Reconstruction w Occluder\",\n", + " ],\n", + " title=\"External Force Demo\",\n", " # bottom_text = \"3DP3 + Physics Prior v1\"\n", " )\n", " for (g, p, po) in zip(gt_with_occ_images, pred_images, pred_with_occ_images)\n", @@ -732,8 +789,8 @@ "metadata": {}, "outputs": [], "source": [ - "def make_gif(images, filename, fps = 10):\n", - " duration = int(1000/fps)\n", + "def make_gif(images, filename, fps=10):\n", + " duration = int(1000 / fps)\n", " images[0].save(\n", " fp=filename,\n", " format=\"GIF\",\n", @@ -743,6 +800,7 @@ " loop=0,\n", " )\n", "\n", + "\n", "make_gif(viz_images, \"scene_gifs/external_force_scene_demo.gif\", fps=20)" ] } diff --git a/scripts/experiments/mcs/physics.ipynb 
b/scripts/experiments/mcs/physics.ipynb index 393d27b3..73ced2dd 100644 --- a/scripts/experiments/mcs/physics.ipynb +++ b/scripts/experiments/mcs/physics.ipynb @@ -25,7 +25,12 @@ "metadata": {}, "outputs": [], "source": [ - "scene_regex = os.path.join(b.utils.get_assets_dir(), \"mcs_scene_jsons\", \"eval_6_validation\", \"passive_physics_spatio_temporal_continuity*\")\n", + "scene_regex = os.path.join(\n", + " b.utils.get_assets_dir(),\n", + " \"mcs_scene_jsons\",\n", + " \"eval_6_validation\",\n", + " \"passive_physics_spatio_temporal_continuity*\",\n", + ")\n", "# scene_regex = os.path.join(j.utils.get_assets_dir(), \"mcs_scene_jsons\", \"eval_6_validation\", \"passive_physics_object*\")\n", "files = sorted(glob.glob(scene_regex))\n", "files" @@ -41,13 +46,15 @@ "def load_mcs_scene_data(scene_path):\n", " cache_dir = os.path.join(b.utils.get_assets_dir(), \"mcs_cache\")\n", " scene_name = scene_path.split(\"/\")[-1]\n", - " \n", + "\n", " cache_filename = os.path.join(cache_dir, f\"{scene_name}.npz\")\n", " if os.path.exists(cache_filename):\n", - " images = np.load(cache_filename,allow_pickle=True)[\"arr_0\"]\n", + " images = np.load(cache_filename, allow_pickle=True)[\"arr_0\"]\n", " else:\n", " controller = mcs.create_controller(\n", - " os.path.join(b.utils.get_assets_dir(), \"mcs_scene_jsons\", \"config_level2.ini\")\n", + " os.path.join(\n", + " b.utils.get_assets_dir(), \"mcs_scene_jsons\", \"config_level2.ini\"\n", + " )\n", " )\n", "\n", " scene_data = mcs.load_scene_json_file(scene_path)\n", @@ -89,7 +96,6 @@ "images = load_mcs_scene_data(filename)[85:]\n", "\n", "\n", - "\n", "filename = files[1]\n", "images = load_mcs_scene_data(filename)[85:]\n", "\n", @@ -121,8 +127,7 @@ "# images = jax3dp3.physics.load_mcs_scene_data(filename)[90:152]\n", "\n", "# filename = files[9]\n", - "# images = jax3dp3.physics.load_mcs_scene_data(filename)[90:152]\n", - "\n" + "# images = jax3dp3.physics.load_mcs_scene_data(filename)[90:152]" ] }, { @@ -157,29 +162,28 @@ "original_intrinsics = images[0].intrinsics\n", "intrinsics = j.camera.scale_camera_parameters(original_intrinsics, 0.25)\n", "intrinsics = j.Intrinsics(\n", - " intrinsics.height, intrinsics.width,\n", + " intrinsics.height,\n", + " intrinsics.width,\n", " intrinsics.fx,\n", " intrinsics.fy,\n", " intrinsics.cx,\n", " intrinsics.cy,\n", " intrinsics.near,\n", - " WALL_Z + 0.1\n", + " WALL_Z + 0.1,\n", ")\n", "\n", "\n", - "dx = 0.7\n", + "dx = 0.7\n", "dy = 0.7\n", "dz = 0.7\n", "gridding = [\n", + " j.make_translation_grid_enumeration(-dx, -dy, -dz, dx, dy, dz, 21, 15, 15),\n", " j.make_translation_grid_enumeration(\n", - " -dx, -dy, -dz, dx, dy, dz, 21,15,15\n", + " -dx / 2.0, -dy / 2, -dz / 2, dx / 2, dy / 2, dz / 2, 21, 15, 15\n", " ),\n", " j.make_translation_grid_enumeration(\n", - " -dx/2.0, -dy/2, -dz/2, dx/2, dy/2, dz/2, 21,15,15\n", + " -dx / 10.0, -dy / 10, -dz / 10, dx / 10, dy / 10, dz / 10, 21, 15, 15\n", " ),\n", - " j.make_translation_grid_enumeration(\n", - " -dx/10.0, -dy/10, -dz/10, dx/10, dy/10, dz/10, 21,15,15\n", - " )\n", "]" ] }, @@ -196,32 +200,40 @@ " for id in segmentation_ids:\n", " point_cloud_segment = point_cloud_image[segmentation == id]\n", " bbox_dims, pose = j.utils.aabb(point_cloud_segment)\n", - " is_occluder = jnp.logical_or(jnp.logical_or(jnp.logical_or(jnp.logical_or(\n", - " (bbox_dims[0] < 0.1),\n", - " (bbox_dims[1] < 0.1)),\n", - " (bbox_dims[1] > 1.1)),\n", - " (bbox_dims[0] > 1.1)),\n", - " (bbox_dims[2] > 2.1)\n", + " is_occluder = jnp.logical_or(\n", + " jnp.logical_or(\n", + " 
jnp.logical_or(\n", + " jnp.logical_or((bbox_dims[0] < 0.1), (bbox_dims[1] < 0.1)),\n", + " (bbox_dims[1] > 1.1),\n", + " ),\n", + " (bbox_dims[0] > 1.1),\n", + " ),\n", + " (bbox_dims[2] > 2.1),\n", " )\n", " if not is_occluder:\n", - " object_mask += (segmentation == id)\n", + " object_mask += segmentation == id\n", " object_ids.append(id)\n", "\n", " object_mask = jnp.array(object_mask) > 0\n", " return object_ids, object_mask\n", "\n", - "def prior3(new_pose, prev_pose, prev_prev_pose, bbox_dims): \n", + "\n", + "def prior3(new_pose, prev_pose, prev_prev_pose, bbox_dims):\n", " score = 0.0\n", - " new_position = new_pose[:3,3]\n", - " bottom_of_object_y = new_position[1] + bbox_dims[1]/2.0\n", + " new_position = new_pose[:3, 3]\n", + " bottom_of_object_y = new_position[1] + bbox_dims[1] / 2.0\n", "\n", - " prev_position = prev_pose[:3,3]\n", - " prev_prev_position = prev_prev_pose[:3,3]\n", + " prev_position = prev_pose[:3, 3]\n", + " prev_prev_position = prev_prev_pose[:3, 3]\n", "\n", " velocity_prev = (prev_position - prev_prev_position) * jnp.array([1.0, 1.0, 0.25])\n", - " velocity_with_gravity = velocity_prev + jnp.array([-jnp.sign(velocity_prev[0])*0.01, 0.02, 0.0])\n", + " velocity_with_gravity = velocity_prev + jnp.array(\n", + " [-jnp.sign(velocity_prev[0]) * 0.01, 0.02, 0.0]\n", + " )\n", "\n", - " velocity_with_gravity2 = velocity_with_gravity * jnp.array([1.0 * (jnp.abs(velocity_with_gravity[0]) > 0.1), 1.0, 1.0 ])\n", + " velocity_with_gravity2 = velocity_with_gravity * jnp.array(\n", + " [1.0 * (jnp.abs(velocity_with_gravity[0]) > 0.1), 1.0, 1.0]\n", + " )\n", " velocity = velocity_with_gravity2\n", "\n", " pred_new_position = prev_position + velocity\n", @@ -232,26 +244,31 @@ " score += -100.0 * (bottom_of_object_y > 1.5)\n", " return score\n", "\n", + "\n", "prior_jit = jax.jit(prior3)\n", - "prior_parallel_jit = jax.jit(jax.vmap(prior3, in_axes=(0, None, None, None)))\n", + "prior_parallel_jit = jax.jit(jax.vmap(prior3, in_axes=(0, None, None, None)))\n", + "\n", "\n", "def update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES):\n", " for known_id in range(OBJECT_POSES.shape[0]):\n", - "\n", " current_pose_estimate = OBJECT_POSES[known_id, :, :]\n", "\n", " for gridding_iter in range(len(gridding)):\n", " all_pose_proposals = [\n", - " jnp.einsum(\"aij,jk->aik\", \n", + " jnp.einsum(\n", + " \"aij,jk->aik\",\n", " gridding[gridding_iter],\n", " current_pose_estimate,\n", " )\n", " ]\n", " if gridding_iter == 0:\n", " for seg_id in object_ids:\n", - " _, center_pose = j.utils.aabb(point_cloud_image[segmentation==seg_id])\n", + " _, center_pose = j.utils.aabb(\n", + " point_cloud_image[segmentation == seg_id]\n", + " )\n", " all_pose_proposals.append(\n", - " jnp.einsum(\"aij,jk->aik\", \n", + " jnp.einsum(\n", + " \"aij,jk->aik\",\n", " gridding[gridding_iter],\n", " center_pose,\n", " )\n", @@ -259,20 +276,26 @@ " all_pose_proposals = jnp.vstack(all_pose_proposals)\n", "\n", " all_weights = []\n", - " for batch in jnp.array_split(all_pose_proposals,3):\n", - " \n", - " rendered_images = renderer.render_parallel(batch, known_id)[...,:3]\n", - " rendered_images_spliced = j.splice_image_parallel(rendered_images, point_cloud_image_complement)\n", + " for batch in jnp.array_split(all_pose_proposals, 3):\n", + " rendered_images = renderer.render_parallel(batch, known_id)[..., :3]\n", + " rendered_images_spliced = j.splice_image_parallel(\n", + " rendered_images, point_cloud_image_complement\n", + " )\n", "\n", " weights = j.threedp3_likelihood_parallel_jit(\n", - " 
point_cloud_image, rendered_images_spliced, R, OUTLIER_PROB, OUTLIER_VOLUME, 3\n", + " point_cloud_image,\n", + " rendered_images_spliced,\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + " 3,\n", " ).reshape(-1)\n", "\n", " prev_pose = ALL_OBJECT_POSES[-1][known_id]\n", " if ALL_OBJECT_POSES[-2].shape[0] <= known_id:\n", - " prev_prev_pose = ALL_OBJECT_POSES[-1][known_id]\n", + " prev_prev_pose = ALL_OBJECT_POSES[-1][known_id]\n", " else:\n", - " prev_prev_pose = ALL_OBJECT_POSES[-2][known_id]\n", + " prev_prev_pose = ALL_OBJECT_POSES[-2][known_id]\n", "\n", " weights += prior_parallel_jit(\n", " batch, prev_pose, prev_prev_pose, renderer.model_box_dims[known_id]\n", @@ -286,6 +309,7 @@ " OBJECT_POSES = OBJECT_POSES.at[known_id].set(current_pose_estimate)\n", " return OBJECT_POSES\n", "\n", + "\n", "def add_new_objects(OBJECT_POSES):\n", " for seg_id in object_ids:\n", " average_probability = jnp.mean(pixelwise_probs[segmentation == seg_id])\n", @@ -301,8 +325,12 @@ " continue\n", "\n", " rows, cols = jnp.where(segmentation == seg_id)\n", - " distance_to_edge_1 = min(jnp.abs(rows - 0).min(), jnp.abs(rows - intrinsics.height).min())\n", - " distance_to_edge_2 = min(jnp.abs(cols - 0).min(), jnp.abs(cols - intrinsics.width).min())\n", + " distance_to_edge_1 = min(\n", + " jnp.abs(rows - 0).min(), jnp.abs(rows - intrinsics.height).min()\n", + " )\n", + " distance_to_edge_2 = min(\n", + " jnp.abs(cols - 0).min(), jnp.abs(cols - intrinsics.width).min()\n", + " )\n", "\n", " point_cloud_segment = point_cloud_image[segmentation == seg_id]\n", " dims, pose = j.utils.aabb(point_cloud_segment)\n", @@ -315,21 +343,33 @@ "\n", " resolution = 0.01\n", " voxelized = jnp.rint(point_cloud_segment / resolution).astype(jnp.int32)\n", - " min_z = voxelized[:,2].min()\n", - " depth = voxelized[:,2].max() - voxelized[:,2].min()\n", + " min_z = voxelized[:, 2].min()\n", + " depth = voxelized[:, 2].max() - voxelized[:, 2].min()\n", "\n", - " front_face = voxelized[voxelized[:,2] <= min_z+20, :]\n", + " front_face = voxelized[voxelized[:, 2] <= min_z + 20, :]\n", " slices = [front_face]\n", " for i in range(depth):\n", " slices.append(front_face + jnp.array([0.0, 0.0, i]))\n", " full_shape = jnp.vstack(slices) * resolution\n", "\n", - " print(\"Seg ID: \", seg_id, \"Prob: \", average_probability, \" Pixels: \",num_pixels, \" dists: \", distance_to_edge_1, \" \", distance_to_edge_2, \" Pose: \", pose[:3, 3])\n", + " print(\n", + " \"Seg ID: \",\n", + " seg_id,\n", + " \"Prob: \",\n", + " average_probability,\n", + " \" Pixels: \",\n", + " num_pixels,\n", + " \" dists: \",\n", + " distance_to_edge_1,\n", + " \" \",\n", + " distance_to_edge_2,\n", + " \" Pose: \",\n", + " pose[:3, 3],\n", + " )\n", "\n", " dims, pose = j.utils.aabb(full_shape)\n", " mesh = j.mesh.make_marching_cubes_mesh_from_point_cloud(\n", - " j.t3d.apply_transform(full_shape, j.t3d.inverse_pose(pose)),\n", - " 0.075\n", + " j.t3d.apply_transform(full_shape, j.t3d.inverse_pose(pose)), 0.075\n", " )\n", "\n", " renderer.add_mesh(mesh)\n", @@ -347,12 +387,12 @@ "outputs": [], "source": [ "R = 0.01\n", - "OUTLIER_PROB=0.01\n", - "OUTLIER_VOLUME=100.0\n", + "OUTLIER_PROB = 0.01\n", + "OUTLIER_VOLUME = 100.0\n", "ALL_OBJECT_POSES = [jnp.zeros((0, 4, 4))]\n", "t = 0\n", "\n", - "renderer = j.Renderer(intrinsics)\n" + "renderer = j.Renderer(intrinsics)" ] }, { @@ -367,24 +407,34 @@ " image = images[t]\n", " depth = j.utils.resize(image.depth, intrinsics.height, intrinsics.width)\n", " point_cloud_image = j.t3d.unproject_depth(depth, 
intrinsics)\n", - " segmentation = j.utils.resize(image.segmentation, intrinsics.height, intrinsics.width)\n", + " segmentation = j.utils.resize(\n", + " image.segmentation, intrinsics.height, intrinsics.width\n", + " )\n", " segmentation_ids = jnp.unique(segmentation)\n", - " object_ids, object_mask = get_object_mask(point_cloud_image, segmentation, segmentation_ids)\n", + " object_ids, object_mask = get_object_mask(\n", + " point_cloud_image, segmentation, segmentation_ids\n", + " )\n", " j.get_depth_image(1.0 * object_mask)\n", " depth_complement = depth * (1.0 - object_mask) + intrinsics.far * (object_mask)\n", " point_cloud_image_complement = j.t3d.unproject_depth(depth_complement, intrinsics)\n", "\n", " OBJECT_POSES = jnp.array(ALL_OBJECT_POSES[-1])\n", " OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)\n", - "# OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)\n", - "# OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)\n", - "\n", - " rerendered = renderer.render_multiobject(OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0]))[...,:3]\n", - " rerendered_spliced = j.splice_image_parallel(jnp.array([rerendered]), point_cloud_image_complement)[0]\n", - " pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(point_cloud_image, rerendered_spliced, R, 0.0, 1.0, 5)\n", + " # OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)\n", + " # OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)\n", + "\n", + " rerendered = renderer.render_multiobject(\n", + " OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0])\n", + " )[..., :3]\n", + " rerendered_spliced = j.splice_image_parallel(\n", + " jnp.array([rerendered]), point_cloud_image_complement\n", + " )[0]\n", + " pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(\n", + " point_cloud_image, rerendered_spliced, R, 0.0, 1.0, 5\n", + " )\n", "\n", " OBJECT_POSES = add_new_objects(OBJECT_POSES)\n", - " \n", + "\n", " ALL_OBJECT_POSES.append(OBJECT_POSES)" ] }, @@ -400,39 +450,46 @@ " image = images[t]\n", " depth = j.utils.resize(image.depth, intrinsics.height, intrinsics.width)\n", " point_cloud_image = j.t3d.unproject_depth(depth, intrinsics)\n", - " segmentation = j.utils.resize(image.segmentation, intrinsics.height, intrinsics.width)\n", + " segmentation = j.utils.resize(\n", + " image.segmentation, intrinsics.height, intrinsics.width\n", + " )\n", " segmentation_ids = jnp.unique(segmentation)\n", - "# object_ids, object_mask = j.physics.get_object_mask(point_cloud_image, segmentation, segmentation_ids)\n", + " # object_ids, object_mask = j.physics.get_object_mask(point_cloud_image, segmentation, segmentation_ids)\n", " j.get_depth_image(1.0 * object_mask)\n", " depth_complement = depth * (1.0 - object_mask) + intrinsics.far * (object_mask)\n", " point_cloud_image_complement = j.t3d.unproject_depth(depth_complement, intrinsics)\n", "\n", " OBJECT_POSES = ALL_OBJECT_POSES[t]\n", - " rerendered = renderer.render_multiobject(OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0]))\n", - " rerendered_spliced = j.splice_image_parallel(jnp.array([rerendered[...,:3]]), point_cloud_image_complement)[0]\n", - " pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(point_cloud_image, rerendered_spliced, R, 0.0, 1.0, 5)\n", - "\n", + " rerendered = renderer.render_multiobject(\n", + " OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0])\n", + " )\n", + " rerendered_spliced = j.splice_image_parallel(\n", + " jnp.array([rerendered[..., :3]]), point_cloud_image_complement\n", + 
" )[0]\n", + " pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(\n", + " point_cloud_image, rerendered_spliced, R, 0.0, 1.0, 5\n", + " )\n", "\n", " weights = []\n", " if t >= 2:\n", " for known_id in range(len(ALL_OBJECT_POSES[t])):\n", - " if ALL_OBJECT_POSES[t-1].shape[0] <= known_id:\n", + " if ALL_OBJECT_POSES[t - 1].shape[0] <= known_id:\n", " continue\n", "\n", - " if ALL_OBJECT_POSES[t-2].shape[0] <= known_id:\n", + " if ALL_OBJECT_POSES[t - 2].shape[0] <= known_id:\n", " continue\n", "\n", - " prev_pose = ALL_OBJECT_POSES[t-1][known_id]\n", - " if ALL_OBJECT_POSES[t-2].shape[0] <= known_id:\n", - " prev_prev_pose = ALL_OBJECT_POSES[t-1][known_id]\n", + " prev_pose = ALL_OBJECT_POSES[t - 1][known_id]\n", + " if ALL_OBJECT_POSES[t - 2].shape[0] <= known_id:\n", + " prev_prev_pose = ALL_OBJECT_POSES[t - 1][known_id]\n", " else:\n", - " prev_prev_pose = ALL_OBJECT_POSES[t-2][known_id]\n", + " prev_prev_pose = ALL_OBJECT_POSES[t - 2][known_id]\n", "\n", " weight = prior_jit(\n", " ALL_OBJECT_POSES[t][known_id],\n", - " ALL_OBJECT_POSES[t-1][known_id],\n", - " ALL_OBJECT_POSES[t-2][known_id],\n", - " renderer.model_box_dims[known_id]\n", + " ALL_OBJECT_POSES[t - 1][known_id],\n", + " ALL_OBJECT_POSES[t - 2][known_id],\n", + " renderer.model_box_dims[known_id],\n", " ).reshape(-1)\n", " weights.append(weight)\n", "\n", @@ -443,7 +500,7 @@ " rerendered,\n", " rerendered_spliced,\n", " pixelwise_probs,\n", - " weights\n", + " weights,\n", " )\n", " )\n", "weights_over_time = [jnp.array(d[-1]).sum() for d in data]" @@ -460,21 +517,23 @@ "import io\n", "import numpy as np\n", "from PIL import Image\n", + "\n", "t = 50\n", "\n", - "def make_plot(x,y, xlabel):\n", + "\n", + "def make_plot(x, y, xlabel):\n", " plt.clf()\n", - " color = np.array([229, 107, 111])/255.0\n", - " plt.plot(x,y, color=color)\n", + " color = np.array([229, 107, 111]) / 255.0\n", + " plt.plot(x, y, color=color)\n", " plt.xlim(0, len(images))\n", " plt.ylim(-800.0, 100.0)\n", - " plt.xlabel(\"Time\",fontsize=20)\n", - " plt.ylabel(\"Log Probability\",fontsize=20)\n", + " plt.xlabel(\"Time\", fontsize=20)\n", + " plt.ylabel(\"Log Probability\", fontsize=20)\n", " plt.tight_layout()\n", " img_buf = io.BytesIO()\n", - " plt.savefig(img_buf, format='png')\n", + " plt.savefig(img_buf, format=\"png\")\n", " im = Image.open(img_buf)\n", - " return im\n" + " return im" ] }, { @@ -496,18 +555,25 @@ "source": [ "viz_panels = []\n", "for t in tqdm(range(len(images))):\n", - " rgb, point_cloud_image, rerendered, rerendered_spliced, pixelwise_probs, weights = data[t]\n", + " rgb, point_cloud_image, rerendered, rerendered_spliced, pixelwise_probs, weights = (\n", + " data[t]\n", + " )\n", " plots = make_plot(np.arange(t), weights_over_time[:t], \"Time\")\n", " factor = rgb.shape[0] / plots.height\n", "\n", - "\n", - " v = j.multi_panel([\n", - " j.get_rgb_image(rgb),\n", - " j.scale_image(j.get_depth_image(rerendered[:,:,2], min=4.0,max=15.0),4),\n", - " j.overlay_image(j.scale_image(j.get_depth_image(rerendered[:,:,2], min=4.0,max=15.0),4), j.get_rgb_image(rgb)),\n", - " j.scale_image(plots, factor)\n", - " ],\n", - " [\"Observed RGB\", \"Inferred Objects\", \"Overlay\", \"Probability\"],\n", + " v = j.multi_panel(\n", + " [\n", + " j.get_rgb_image(rgb),\n", + " j.scale_image(j.get_depth_image(rerendered[:, :, 2], min=4.0, max=15.0), 4),\n", + " j.overlay_image(\n", + " j.scale_image(\n", + " j.get_depth_image(rerendered[:, :, 2], min=4.0, max=15.0), 4\n", + " ),\n", + " j.get_rgb_image(rgb),\n", + " ),\n", + " 
j.scale_image(plots, factor),\n", + " ],\n", + " [\"Observed RGB\", \"Inferred Objects\", \"Overlay\", \"Probability\"],\n", " label_fontsize=50,\n", " )\n", " viz_panels.append(v)\n", @@ -549,7 +615,7 @@ "metadata": {}, "outputs": [], "source": [ - "j.meshcat.show_trimesh(\"1\",renderer.meshes[1])" + "j.meshcat.show_trimesh(\"1\", renderer.meshes[1])" ] }, { @@ -583,7 +649,7 @@ "metadata": {}, "outputs": [], "source": [ - "ALL_OBJECT_POSES.append(OBJECT_POSES)\n" + "ALL_OBJECT_POSES.append(OBJECT_POSES)" ] }, { @@ -601,13 +667,16 @@ "metadata": {}, "outputs": [], "source": [ - "r = jnp.ones_like(point_cloud_image[:,:,2]) * 0.005\n", + "r = jnp.ones_like(point_cloud_image[:, :, 2]) * 0.005\n", "key = jax.random.PRNGKey(10)\n", "noisy_point_cloud_image = jax.random.multivariate_normal(\n", - " key, point_cloud_image[:,:,:3], (jnp.eye(3)[None, None, :, :] * r[:,:,None,None]), shape=r.shape\n", + " key,\n", + " point_cloud_image[:, :, :3],\n", + " (jnp.eye(3)[None, None, :, :] * r[:, :, None, None]),\n", + " shape=r.shape,\n", ")\n", - "img = j.render_point_cloud(noisy_point_cloud_image.reshape(-1,3), intrinsics)\n", - "j.scale_image(j.get_depth_image(img[:,:,2]),10)" + "img = j.render_point_cloud(noisy_point_cloud_image.reshape(-1, 3), intrinsics)\n", + "j.scale_image(j.get_depth_image(img[:, :, 2]), 10)" ] }, { diff --git a/scripts/experiments/slam/slam.ipynb b/scripts/experiments/slam/slam.ipynb index a7a03aa6..f6152d86 100644 --- a/scripts/experiments/slam/slam.ipynb +++ b/scripts/experiments/slam/slam.ipynb @@ -51,15 +51,15 @@ } ], "source": [ - "f = open(os.path.join(b.utils.get_assets_dir(), f\"tum/livingRoom1.gt.freiburg\"),\"r\")\n", + "f = open(os.path.join(b.utils.get_assets_dir(), f\"tum/livingRoom1.gt.freiburg\"), \"r\")\n", "data = f.readlines()\n", - "data = [d.strip('\\n') for d in data]\n", + "data = [d.strip(\"\\n\") for d in data]\n", "poses = [jnp.eye(4)]\n", "\n", "xyzw_to_rotation_matrix = jax.jit(b.t3d.xyzw_to_rotation_matrix)\n", "transform_from_rot_and_pos = jax.jit(b.t3d.transform_from_rot_and_pos)\n", "for i in tqdm(range(len(data))):\n", - " xyzq = list(map(float,data[i].split(\" \")))[1:]\n", + " xyzq = list(map(float, data[i].split(\" \")))[1:]\n", " pos = jnp.array([xyzq[:3]])\n", " rot = b.xyzw_to_rotation_matrix(jnp.array(xyzq[3:]))\n", " pose = transform_from_rot_and_pos(rot, pos)\n", @@ -95,58 +95,79 @@ { "cell_type": "code", "execution_count": 5, + "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "outputs": [], "source": [ "original_intrinsics = b.Intrinsics(\n", - " original_depths[0].shape[0], original_depths[1].shape[1],\n", - " 481.20, 480.00,319.50,239.50,0.001, 6.0\n", + " original_depths[0].shape[0],\n", + " original_depths[1].shape[1],\n", + " 481.20,\n", + " 480.00,\n", + " 319.50,\n", + " 239.50,\n", + " 0.001,\n", + " 6.0,\n", ")\n", "intrinsics = b.camera.scale_camera_parameters(original_intrinsics, 0.2)\n", - "depths = [b.utils.resize(d, intrinsics.height, intrinsics.width) for d in original_depths]" + "depths = [\n", + " b.utils.resize(d, intrinsics.height, intrinsics.width) for d in original_depths\n", + "]" ] }, { "cell_type": "code", "execution_count": 6, + "id": "acae54e37e7d407bbb7b55eff062a284", "metadata": {}, "outputs": [], "source": [ "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", + "\n", "jax_renderer = JaxRenderer(intrinsics)" ] }, { "cell_type": "code", "execution_count": 7, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": {}, "outputs": [], "source": [ "b.clear()\n", 
"point_cloud_first = b.unproject_depth_jit(depths[0], intrinsics)\n", - "b.show_cloud(\"1\", point_cloud_first.reshape(-1,3))" + "b.show_cloud(\"1\", point_cloud_first.reshape(-1, 3))" ] }, { "cell_type": "code", "execution_count": 9, + "id": "8dd0d8092fe74a7c96281538738b07e2", "metadata": {}, "outputs": [], "source": [ "def point_cloud_image_to_trimesh(point_cloud_image):\n", " height, width, _ = point_cloud_image.shape\n", - " ij_to_index = lambda i,j: i * width + j\n", + " ij_to_index = lambda i, j: i * width + j\n", " ij_to_faces = lambda ij: jnp.array(\n", " [\n", - " [ij_to_index(ij[0], ij[1]), ij_to_index(ij[0]+1, ij[1]), ij_to_index(ij[0], ij[1]+1)],\n", - " [ij_to_index(ij[0]+1, ij[1]), ij_to_index(ij[0]+1, ij[1]+1), ij_to_index(ij[0], ij[1]+1)]\n", + " [\n", + " ij_to_index(ij[0], ij[1]),\n", + " ij_to_index(ij[0] + 1, ij[1]),\n", + " ij_to_index(ij[0], ij[1] + 1),\n", + " ],\n", + " [\n", + " ij_to_index(ij[0] + 1, ij[1]),\n", + " ij_to_index(ij[0] + 1, ij[1] + 1),\n", + " ij_to_index(ij[0], ij[1] + 1),\n", + " ],\n", " ]\n", " )\n", - " jj, ii = jnp.meshgrid(jnp.arange(width-1), jnp.arange(height-1))\n", - " indices = jnp.stack([ii,jj],axis=-1)\n", - " faces = jax.vmap(ij_to_faces)(indices.reshape(-1,2)).reshape(-1,3)\n", + " jj, ii = jnp.meshgrid(jnp.arange(width - 1), jnp.arange(height - 1))\n", + " indices = jnp.stack([ii, jj], axis=-1)\n", + " faces = jax.vmap(ij_to_faces)(indices.reshape(-1, 2)).reshape(-1, 3)\n", " print(faces.shape)\n", - " vertices = point_cloud_image.reshape(-1,3)\n", + " vertices = point_cloud_image.reshape(-1, 3)\n", " mesh = trimesh.Trimesh(vertices, faces)\n", " return mesh" ] @@ -154,6 +175,7 @@ { "cell_type": "code", "execution_count": 11, + "id": "72eea5119410473aa328ad9291626812", "metadata": {}, "outputs": [ { @@ -166,34 +188,47 @@ ], "source": [ "mesh = point_cloud_image_to_trimesh(point_cloud_first)\n", - "b.show_trimesh(\"mesh\", mesh) \n", - "vertices,faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)" + "b.show_trimesh(\"mesh\", mesh)\n", + "vertices, faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)" ] }, { "cell_type": "code", "execution_count": 110, + "id": "8edb47106e1a46a883d545849b8ab81b", "metadata": {}, "outputs": [], "source": [ - "def render_image_from_pose(trans,q):\n", + "def render_image_from_pose(trans, q):\n", " camera_pose = b.translation_and_quaternion_to_pose_matrix(trans, q)\n", - " img = jax_renderer.render(vertices, faces, b.inverse_pose(camera_pose), intrinsics)[0][0,...]\n", + " img = jax_renderer.render(vertices, faces, b.inverse_pose(camera_pose), intrinsics)[\n", + " 0\n", + " ][0, ...]\n", " return img\n", "\n", + "\n", "def loss_func(trans, q, obs_depth):\n", - " img = render_image_from_pose(trans,q)\n", + " img = render_image_from_pose(trans, q)\n", " diffs = jnp.abs(img - obs_depth)\n", " return (diffs * (img > 0.001)).mean()\n", " # return jnp.abs(img - obs_depth).mean()\n", - " \n", "\n", - "value_and_grad_jit = jax.jit(jax.value_and_grad(loss_func, argnums=(0,1,)))" + "\n", + "value_and_grad_jit = jax.jit(\n", + " jax.value_and_grad(\n", + " loss_func,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " ),\n", + " )\n", + ")" ] }, { "cell_type": "code", "execution_count": 116, + "id": "10185d26023b46108eb7d9f57d49d2b3", "metadata": {}, "outputs": [ { @@ -222,20 +257,25 @@ } ], "source": [ - "tr,q = jnp.zeros(3), jnp.array([1.0, 0.0, 0.0, 0.0])\n", + "tr, q = jnp.zeros(3), jnp.array([1.0, 0.0, 0.0, 0.0])\n", "\n", "timestep = 0\n", "\n", - "mesh = 
point_cloud_image_to_trimesh(b.apply_transform(b.unproject_depth_jit(depths[timestep], intrinsics), b.translation_and_quaternion_to_pose_matrix(tr,q)))\n", + "mesh = point_cloud_image_to_trimesh(\n", + " b.apply_transform(\n", + " b.unproject_depth_jit(depths[timestep], intrinsics),\n", + " b.translation_and_quaternion_to_pose_matrix(tr, q),\n", + " )\n", + ")\n", "b.clear()\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)\n", - "b.show_trimesh(\"mesh\", mesh) \n", - "vertices,faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)\n", + "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)\n", + "b.show_trimesh(\"mesh\", mesh)\n", + "vertices, faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)\n", "\n", - "print(b.translation_and_quaternion_to_pose_matrix(tr,q))\n", - "img = render_image_from_pose(tr,q)\n", + "print(b.translation_and_quaternion_to_pose_matrix(tr, q))\n", + "img = render_image_from_pose(tr, q)\n", "print(img.shape)\n", - "print(loss_func(tr,q, depths[0]))\n", + "print(loss_func(tr, q, depths[0]))\n", "\n", "b.hstack_images([b.get_depth_image(img), b.get_depth_image(depths[0])])" ] @@ -243,6 +283,7 @@ { "cell_type": "code", "execution_count": 123, + "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": {}, "outputs": [ { @@ -256,17 +297,18 @@ "source": [ "pbar = tqdm(range(200))\n", "timestep = 100\n", - "for _ in pbar:\n", + "for _ in pbar:\n", " loss, (g1, g2) = value_and_grad_jit(tr, q, depths[timestep])\n", " tr -= g1 * 0.001\n", " q -= g2 * 0.001\n", " pbar.set_description(f\"{loss}\")\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)" + "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)" ] }, { "cell_type": "code", "execution_count": 125, + "id": "7623eae2785240b9bd12b16a66d81610", "metadata": {}, "outputs": [ { @@ -282,21 +324,21 @@ } ], "source": [ - "rendered_img = render_image_from_pose(tr,q)\n", + "rendered_img = render_image_from_pose(tr, q)\n", "b.clear()\n", - "b.show_cloud(\"cloud\", b.unproject_depth_jit(rendered_img, intrinsics).reshape(-1,3))\n", - "b.show_cloud(\"cloud2\", b.unproject_depth_jit(depths[timestep], intrinsics).reshape(-1,3),color=b.RED)\n", - "b.hstack_images(\n", - " [\n", - " b.get_depth_image(rendered_img),\n", - " b.get_depth_image(depths[timestep])\n", - " ]\n", - ")" + "b.show_cloud(\"cloud\", b.unproject_depth_jit(rendered_img, intrinsics).reshape(-1, 3))\n", + "b.show_cloud(\n", + " \"cloud2\",\n", + " b.unproject_depth_jit(depths[timestep], intrinsics).reshape(-1, 3),\n", + " color=b.RED,\n", + ")\n", + "b.hstack_images([b.get_depth_image(rendered_img), b.get_depth_image(depths[timestep])])" ] }, { "cell_type": "code", "execution_count": 98, + "id": "7cdc8c89c7104fffa095e18ddfef8986", "metadata": {}, "outputs": [ { @@ -308,16 +350,22 @@ } ], "source": [ - "mesh = point_cloud_image_to_trimesh(b.apply_transform(b.unproject_depth_jit(depths[timestep], intrinsics), b.translation_and_quaternion_to_pose_matrix(tr,q)))\n", + "mesh = point_cloud_image_to_trimesh(\n", + " b.apply_transform(\n", + " b.unproject_depth_jit(depths[timestep], intrinsics),\n", + " b.translation_and_quaternion_to_pose_matrix(tr, q),\n", + " )\n", + ")\n", "b.clear()\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)\n", - "b.show_trimesh(\"mesh\", mesh) \n", - "vertices,faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)" + "b.show_pose(\"inferred\", 
b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)\n", + "b.show_trimesh(\"mesh\", mesh)\n", + "vertices, faces = jnp.array(mesh.vertices), jnp.array(mesh.faces)" ] }, { "cell_type": "code", "execution_count": 42, + "id": "b118ea5561624da68c537baed56e602f", "metadata": {}, "outputs": [ { @@ -359,12 +407,13 @@ } ], "source": [ - "tr,q = jnp.zeros(3), jnp.array([1.0, 0.0, 0.0, 0.0])\n", - "print(b.translation_and_quaternion_to_pose_matrix(tr,q))\n", - "img = render_image_from_pose(tr,q)\n", + "tr, q = jnp.zeros(3), jnp.array([1.0, 0.0, 0.0, 0.0])\n", + "print(b.translation_and_quaternion_to_pose_matrix(tr, q))\n", + "img = render_image_from_pose(tr, q)\n", "print(img.shape)\n", "b.get_depth_image(img)\n", "import matplotlib.pyplot as plt\n", + "\n", "plt.imshow(img)\n", "plt.colorbar()\n", "img" @@ -373,6 +422,7 @@ { "cell_type": "code", "execution_count": 35, + "id": "938c804e27f84196a10c8828c723f798", "metadata": {}, "outputs": [ { @@ -394,16 +444,17 @@ { "cell_type": "code", "execution_count": 32, + "id": "504fb2a444614c0babb325280ed9130a", "metadata": {}, "outputs": [], "source": [ "b.clear()\n", "cloud1 = b.unproject_depth_jit(original_depths[0], original_intrinsics)\n", - "b.show_cloud(\"1\", b.apply_transform_jit(cloud1,poses[0]).reshape(-1,3))\n", + "b.show_cloud(\"1\", b.apply_transform_jit(cloud1, poses[0]).reshape(-1, 3))\n", "\n", "T = 100\n", "cloud2 = b.unproject_depth_jit(original_depths[T], original_intrinsics)\n", - "b.show_cloud(\"2\", b.apply_transform_jit(cloud2, poses[T]).reshape(-1,3), color=b.RED)" + "b.show_cloud(\"2\", b.apply_transform_jit(cloud2, poses[T]).reshape(-1, 3), color=b.RED)" ] }, { @@ -447,10 +498,12 @@ "metadata": {}, "outputs": [], "source": [ - "j.hstack_images([\n", - " j.get_rgb_image(rgbs[T1]),\n", - " j.get_rgb_image(rgbs[T2]),\n", - "])" + "j.hstack_images(\n", + " [\n", + " j.get_rgb_image(rgbs[T1]),\n", + " j.get_rgb_image(rgbs[T2]),\n", + " ]\n", + ")" ] }, { @@ -460,10 +513,12 @@ "metadata": {}, "outputs": [], "source": [ - "j.hstack_images([\n", - " j.get_depth_image(point_cloud_image_1[:,:,2]),\n", - " j.get_depth_image(point_cloud_image_2[:,:,2]),\n", - "])" + "j.hstack_images(\n", + " [\n", + " j.get_depth_image(point_cloud_image_1[:, :, 2]),\n", + " j.get_depth_image(point_cloud_image_2[:, :, 2]),\n", + " ]\n", + ")" ] }, { @@ -473,7 +528,7 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = j.mesh.make_voxel_mesh_from_point_cloud(point_cloud_image_1.reshape(-1,3), 0.05)" + "mesh = j.mesh.make_voxel_mesh_from_point_cloud(point_cloud_image_1.reshape(-1, 3), 0.05)" ] }, { @@ -505,7 +560,7 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_trimesh(\"1\",mesh)" + "j.meshcat.show_trimesh(\"1\", mesh)" ] }, { @@ -516,11 +571,17 @@ "outputs": [], "source": [ "pose_estimate = jnp.eye(4)\n", - "NUM_SAMPLES_FOR_ESTIMATE=500\n", + "NUM_SAMPLES_FOR_ESTIMATE = 500\n", "keys = jax.random.split(jax.random.PRNGKey(4), NUM_SAMPLES_FOR_ESTIMATE)\n", - "var,conc = 0.01, 1000.0\n", - "get_proposals = jax.jit(jax.vmap(lambda key,pose_estimate,var,conc: j.distributions.gaussian_vmf_sample(\n", - " key, pose_estimate, var, conc),in_axes=(0,None,None,None)))\n", + "var, conc = 0.01, 1000.0\n", + "get_proposals = jax.jit(\n", + " jax.vmap(\n", + " lambda key, pose_estimate, var, conc: j.distributions.gaussian_vmf_sample(\n", + " key, pose_estimate, var, conc\n", + " ),\n", + " in_axes=(0, None, None, None),\n", + " )\n", + ")\n", "best_score = -jnp.inf" ] }, @@ -532,21 +593,23 @@ "outputs": [], "source": [ "for _ in 
tqdm(range(10)):\n", - " for (var,conc) in [(0.01, 2000.0),(0.001, 2000.0),(0.1, 600.0)]:\n", + " for var, conc in [(0.01, 2000.0), (0.001, 2000.0), (0.1, 600.0)]:\n", " pose_proposals = get_proposals(keys, pose_estimate, var, conc)\n", " keys = jax.random.split(keys[0], NUM_SAMPLES_FOR_ESTIMATE)\n", - " rendered_images = renderer.render_parallel(pose_proposals, 0)[...,:3]\n", - " weights = j.threedp3_likelihood_parallel_jit(point_cloud_image_2, rendered_images, 0.0001, 0.1, 1.0)\n", + " rendered_images = renderer.render_parallel(pose_proposals, 0)[..., :3]\n", + " weights = j.threedp3_likelihood_parallel_jit(\n", + " point_cloud_image_2, rendered_images, 0.0001, 0.1, 1.0\n", + " )\n", " if weights.max() > best_score:\n", " best_score = weights.max()\n", " pose_estimate = pose_proposals[weights.argmax()]\n", " print(best_score)\n", "\n", - " reconstruction = renderer.render_single_object(pose_estimate, 0)[:,:,:3]\n", + " reconstruction = renderer.render_single_object(pose_estimate, 0)[:, :, :3]\n", "\n", "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_image_2.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", reconstruction.reshape(-1,3), color=j.RED)\n" + "j.meshcat.show_cloud(\"1\", point_cloud_image_2.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", reconstruction.reshape(-1, 3), color=j.RED)" ] }, { @@ -556,7 +619,7 @@ "metadata": {}, "outputs": [], "source": [ - "j.get_depth_image(renderer.render_single_object(pose_estimate, 0)[...,2])" + "j.get_depth_image(renderer.render_single_object(pose_estimate, 0)[..., 2])" ] }, { @@ -566,7 +629,7 @@ "metadata": {}, "outputs": [], "source": [ - "j.get_depth_image(point_cloud_image_2[:,:,2])" + "j.get_depth_image(point_cloud_image_2[:, :, 2])" ] }, { @@ -588,7 +651,9 @@ "point_cloud_2 = j.t3d.unproject_depth_jit(depths[T2], intrinsics)\n", "\n", "correction_transform = j.t3d.inverse_pose(poses[T1]) @ poses[T2]\n", - "point_cloud_2_corrected = j.t3d.apply_transform_jit(j.t3d.unproject_depth_jit(depths[T2], intrinsics), correction_transform)" + "point_cloud_2_corrected = j.t3d.apply_transform_jit(\n", + " j.t3d.unproject_depth_jit(depths[T2], intrinsics), correction_transform\n", + ")" ] }, { @@ -599,13 +664,20 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1,3), color=j.RED)\n", + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1, 3), color=j.RED)\n", "\n", "R = jnp.array([0.0001])\n", "OUTLIER_PROB = 0.05\n", "OUTLIER_VOLUME = 1.0\n", - "j.threedp3_likelihood_jit(point_cloud_1, point_cloud_2, jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])), R,OUTLIER_PROB, OUTLIER_VOLUME)" + "j.threedp3_likelihood_jit(\n", + " point_cloud_1,\n", + " point_cloud_2,\n", + " jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])),\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + ")" ] }, { @@ -615,18 +687,14 @@ "metadata": {}, "outputs": [], "source": [ - "pc = point_cloud_1.reshape(-1,3)\n", + "pc = point_cloud_1.reshape(-1, 3)\n", "noise = jax.vmap(\n", - " lambda key: jax.random.multivariate_normal(\n", - " key, jnp.zeros(3), jnp.eye(3) * R[0]\n", - " )\n", - ")(\n", - " jax.random.split(jax.random.PRNGKey(3), pc.shape[0])\n", - ")\n", + " lambda key: jax.random.multivariate_normal(key, jnp.zeros(3), jnp.eye(3) * R[0])\n", + ")(jax.random.split(jax.random.PRNGKey(3), pc.shape[0]))\n", "pc_noisy = pc + noise\n", 
"j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", pc.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", pc_noisy.reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\"1\", pc.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", pc_noisy.reshape(-1, 3), color=j.RED)" ] }, { @@ -637,12 +705,18 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_cloud_2_corrected.reshape(-1,3), color=j.RED)\n", + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_cloud_2_corrected.reshape(-1, 3), color=j.RED)\n", "\n", "j.meshcat.show_pose(\"pose\", correction_transform)\n", - "j.threedp3_likelihood_jit(point_cloud_1, j.t3d.apply_transform(point_cloud_2, \n", - " correction_transform), jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])), R,OUTLIER_PROB, OUTLIER_VOLUME)" + "j.threedp3_likelihood_jit(\n", + " point_cloud_1,\n", + " j.t3d.apply_transform(point_cloud_2, correction_transform),\n", + " jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])),\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + ")" ] }, { @@ -652,7 +726,7 @@ "metadata": {}, "outputs": [], "source": [ - "NUM_SAMPLES_FOR_ESTIMATE = 1000\n" + "NUM_SAMPLES_FOR_ESTIMATE = 1000" ] }, { @@ -672,36 +746,52 @@ "metadata": {}, "outputs": [], "source": [ - "threedp3_likelihood_parallel_jit = jax.jit(jax.vmap(\n", - " j.threedp3_likelihood,\n", - " in_axes=(None, 0, None, None, None, None)\n", - "))\n", + "threedp3_likelihood_parallel_jit = jax.jit(\n", + " jax.vmap(j.threedp3_likelihood, in_axes=(None, 0, None, None, None, None))\n", + ")\n", "\n", "\n", - "def refine_pose_estimate_inner(pose_estimate, point_cloud_1, point_cloud_2, keys, var, conc):\n", + "def refine_pose_estimate_inner(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, var, conc\n", + "):\n", " keys = jax.random.split(keys[0], NUM_SAMPLES_FOR_ESTIMATE)\n", - " pose_proposals = jax.vmap(lambda key: j.distributions.gaussian_vmf_sample(\n", - " key, pose_estimate, var, conc))(\n", - " keys\n", - " )\n", + " pose_proposals = jax.vmap(\n", + " lambda key: j.distributions.gaussian_vmf_sample(key, pose_estimate, var, conc)\n", + " )(keys)\n", "\n", " rendered_images = jnp.einsum(\n", - " 'aij,...j->a...i',\n", + " \"aij,...j->a...i\",\n", " pose_proposals,\n", - " jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),\n", + " jnp.concatenate(\n", + " [point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1\n", + " ),\n", " )[..., :-1]\n", "\n", " best_score = j.threedp3_likelihood_jit(\n", - " point_cloud_1, j.t3d.apply_transform(point_cloud_2, \n", - " pose_estimate), jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])), R, OUTLIER_PROB, OUTLIER_VOLUME)\n", - " \n", + " point_cloud_1,\n", + " j.t3d.apply_transform(point_cloud_2, pose_estimate),\n", + " jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])),\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + " )\n", + "\n", " weights = threedp3_likelihood_parallel_jit(\n", - " point_cloud_1, rendered_images, jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])), R,OUTLIER_PROB, OUTLIER_VOLUME)\n", + " point_cloud_1,\n", + " rendered_images,\n", + " jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])),\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + " )\n", " weights_max = weights.max()\n", - " better = (weights_max > best_score)\n", - " pose_estimate = 
pose_proposals[weights.argmax()] * better + pose_estimate* (1.0 - better)\n", + " better = weights_max > best_score\n", + " pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate * (\n", + " 1.0 - better\n", + " )\n", " return pose_estimate, keys\n", "\n", + "\n", "refine_pose_estimate_jit = jax.jit(refine_pose_estimate_inner)" ] }, @@ -713,13 +803,24 @@ "outputs": [], "source": [ "for _ in tqdm(range(20)):\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0)\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0)\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0\n", + " )\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0\n", + " )\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0\n", + " )\n", "score = j.threedp3_likelihood_jit(\n", - " point_cloud_1, j.t3d.apply_transform(point_cloud_2, \n", - " pose_estimate), jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])), R, OUTLIER_PROB, OUTLIER_VOLUME)\n", - "keys[0],score" + " point_cloud_1,\n", + " j.t3d.apply_transform(point_cloud_2, pose_estimate),\n", + " jnp.zeros((point_cloud_2.shape[0], point_cloud_2.shape[1])),\n", + " R,\n", + " OUTLIER_PROB,\n", + " OUTLIER_VOLUME,\n", + ")\n", + "keys[0], score" ] }, { @@ -730,8 +831,10 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\n", + " \"2\", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1, 3), color=j.RED\n", + ")" ] }, { @@ -750,8 +853,8 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1, 3), color=j.RED)" ] }, { @@ -777,19 +880,24 @@ "metadata": {}, "outputs": [], "source": [ - "point_clouds = [\n", - " unproject_depth(depths[t], intrinsics)\n", - " for t in jnp.arange(200, 300, 5)\n", - "]\n", + "point_clouds = [unproject_depth(depths[t], intrinsics) for t in jnp.arange(200, 300, 5)]\n", "transforms = []\n", "for i in tqdm(range(len(point_clouds) - 1)):\n", " pose_estimate = jnp.eye(4)\n", - " point_cloud_1, point_cloud_2 = point_clouds[i], point_clouds[i+1]\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0)\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0)\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)\n", - " pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)\n", - " transforms.append(pose_estimate)\n" + " point_cloud_1, point_cloud_2 = point_clouds[i], point_clouds[i + 1]\n", + " 
pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0\n", + " )\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0\n", + " )\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0\n", + " )\n", + " pose_estimate, keys = refine_pose_estimate_jit(\n", + " pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0\n", + " )\n", + " transforms.append(pose_estimate)" ] }, { @@ -810,8 +918,8 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_clouds[i].reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_clouds[i+1].reshape(-1,3), color=j.RED)\n" + "j.meshcat.show_cloud(\"1\", point_clouds[i].reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_clouds[i + 1].reshape(-1, 3), color=j.RED)" ] }, { @@ -822,8 +930,12 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_clouds[i].reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", j.t3d.apply_transform(point_clouds[i+1], transforms[i]).reshape(-1,3), color=j.RED)\n" + "j.meshcat.show_cloud(\"1\", point_clouds[i].reshape(-1, 3))\n", + "j.meshcat.show_cloud(\n", + " \"2\",\n", + " j.t3d.apply_transform(point_clouds[i + 1], transforms[i]).reshape(-1, 3),\n", + " color=j.RED,\n", + ")" ] }, { @@ -852,7 +964,7 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))" ] }, { @@ -863,7 +975,9 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"2\", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\n", + " \"2\", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1, 3), color=j.RED\n", + ")" ] }, { @@ -900,7 +1014,9 @@ "metadata": {}, "outputs": [], "source": [ - "recontruction = renderer.render_single_object(j.t3d.inverse_pose(poses[T2]) @ poses[T1], 0)" + "recontruction = renderer.render_single_object(\n", + " j.t3d.inverse_pose(poses[T2]) @ poses[T1], 0\n", + ")" ] }, { @@ -930,8 +1046,8 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1, 3), color=j.RED)" ] }, { @@ -951,12 +1067,10 @@ "source": [ "T_WIDTH = 0.01\n", "translation_grid = j.make_translation_grid_enumeration(\n", - " -T_WIDTH,-T_WIDTH,-T_WIDTH,\n", - " T_WIDTH,T_WIDTH,T_WIDTH,\n", - " 11,11,11\n", + " -T_WIDTH, -T_WIDTH, -T_WIDTH, T_WIDTH, T_WIDTH, T_WIDTH, 11, 11, 11\n", ")\n", "rotation_grid = j.make_rotation_grid_enumeration(\n", - " 50, 40, -jnp.pi/40, jnp.pi/40, jnp.pi/40\n", + " 50, 40, -jnp.pi / 40, jnp.pi / 40, jnp.pi / 40\n", ")\n", "\n", "\n", @@ -972,42 +1086,49 @@ "outputs": [], "source": [ "pose_proposals = jnp.einsum(\n", - " 'aij,jk->aik',\n", + " \"aij,jk->aik\",\n", " translation_grid,\n", " pose_estimate,\n", - " \n", ")\n", "rendered_images = jnp.einsum(\n", - " 'aij,...j->a...i',\n", + " \"aij,...j->a...i\",\n", " pose_proposals,\n", - " jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),\n", + " jnp.concatenate(\n", + " [point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + 
(1,))], axis=-1\n", + " ),\n", ")[..., :-1]\n", "\n", - "weights = j.threedp3_likelihood_parallel_jit(point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME)\n", + "weights = j.threedp3_likelihood_parallel_jit(\n", + " point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME\n", + ")\n", "weights_max = weights.max()\n", - "better = (weights_max > best_score)\n", - "pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate* (1.0 - better)\n", + "better = weights_max > best_score\n", + "pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate * (\n", + " 1.0 - better\n", + ")\n", "best_score = weights_max * better + best_score * (1.0 - better)\n", "print(best_score)\n", "\n", - "pose_proposals = jnp.einsum(\n", - " 'ij,ajk->aik',\n", - " pose_estimate,\n", - " rotation_grid\n", - ")\n", + "pose_proposals = jnp.einsum(\"ij,ajk->aik\", pose_estimate, rotation_grid)\n", "rendered_images = jnp.einsum(\n", - " 'aij,...j->a...i',\n", + " \"aij,...j->a...i\",\n", " pose_proposals,\n", - " jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),\n", + " jnp.concatenate(\n", + " [point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1\n", + " ),\n", ")[..., :-1]\n", "\n", "\n", - "weights = j.threedp3_likelihood_parallel_jit(point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME)\n", + "weights = j.threedp3_likelihood_parallel_jit(\n", + " point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME\n", + ")\n", "weights_max = weights.max()\n", - "better = (weights_max > best_score)\n", - "pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate *(1.0 - better)\n", + "better = weights_max > best_score\n", + "pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate * (\n", + " 1.0 - better\n", + ")\n", "best_score = weights_max * better + best_score * (1.0 - better)\n", - "print(best_score)\n" + "print(best_score)" ] }, { @@ -1030,7 +1151,7 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))" ] }, { @@ -1041,8 +1162,8 @@ "outputs": [], "source": [ "j.meshcat.clear()\n", - "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1,3))\n", - "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1,3), color=j.RED)" + "j.meshcat.show_cloud(\"1\", point_cloud_1.reshape(-1, 3))\n", + "j.meshcat.show_cloud(\"2\", point_cloud_2.reshape(-1, 3), color=j.RED)" ] }, { diff --git a/scripts/experiments/slam/slam_2d.ipynb b/scripts/experiments/slam/slam_2d.ipynb index e4e55fa5..a4f47b45 100644 --- a/scripts/experiments/slam/slam_2d.ipynb +++ b/scripts/experiments/slam/slam_2d.ipynb @@ -20,2634 +20,732 @@ "from genjax._src.core.transforms.incremental import NoChange\n", "from genjax._src.core.transforms.incremental import Diff\n", "import trimesh\n", - "console = genjax.pretty(show_locals=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25421e64-5f56-4cbe-aa20-104e7400d8e6", - "metadata": {}, - "outputs": [], - "source": [ - "intrinsics = b.Intrinsics(\n", - " height=1,\n", - " width=50,\n", - " fx=10.0, fy=1.0,\n", - " cx=25.0, cy=0.0,\n", - " near=0.01, far=20.0\n", - ")\n", - "b.setup_renderer(intrinsics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00c9ee51-8744-470b-aeac-59ccb04d0ca6", - "metadata": {}, - "outputs": [], - "source": [ - "b.setup_visualizer()" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "f959c4d7-f2c1-4449-8513-cdba7a4e036a", - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "env_data = {\n", - " \"env_name\": \"014e6d1cee6e6a1297a78f761fbc6700.json\",\n", - " \"paths\": [\n", - " [\n", - " [\n", - " 15.13554761904762,\n", - " 14.694047619047623\n", - " ],\n", - " [\n", - " 14.821214285714289,\n", - " 14.604238095238099\n", - " ],\n", - " [\n", - " 14.551785714285716,\n", - " 14.469523809523814\n", - " ],\n", - " [\n", - " 14.372166666666667,\n", - " 14.245000000000005\n", - " ],\n", - " [\n", - " 14.237452380952382,\n", - " 13.975571428571431\n", - " ],\n", - " [\n", - " 14.102738095238097,\n", - " 13.706142857142861\n", - " ],\n", - " [\n", - " 14.057833333333337,\n", - " 13.346904761904764\n", - " ],\n", - " [\n", - " 14.012928571428573,\n", - " 13.077476190476194\n", - " ],\n", - " [\n", - " 13.878214285714288,\n", - " 12.80804761904762\n", - " ],\n", - " [\n", - " 13.788404761904763,\n", - " 12.359000000000002\n", - " ],\n", - " [\n", - " 13.608785714285714,\n", - " 11.999761904761908\n", - " ],\n", - " [\n", - " 13.518976190476193,\n", - " 11.685428571428574\n", - " ],\n", - " [\n", - " 13.47407142857143,\n", - " 11.371095238095242\n", - " ],\n", - " [\n", - " 13.384261904761905,\n", - " 11.056761904761908\n", - " ],\n", - " [\n", - " 13.24954761904762,\n", - " 10.787333333333336\n", - " ],\n", - " [\n", - " 12.890309523809526,\n", - " 10.60771428571429\n", - " ],\n", - " [\n", - " 12.665785714285716,\n", - " 10.60771428571429\n", - " ],\n", - " [\n", - " 12.261642857142858,\n", - " 10.562809523809527\n", - " ],\n", - " [\n", - " 11.8575,\n", - " 10.517904761904765\n", - " ],\n", - " [\n", - " 11.408452380952381,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 10.9145,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 10.42054761904762,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 10.061309523809525,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 9.567357142857144,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 9.118309523809524,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 8.714166666666667,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 8.175309523809524,\n", - " 10.60771428571429\n", - " ],\n", - " [\n", - " 7.726261904761905,\n", - " 10.87714285714286\n", - " ],\n", - " [\n", - " 7.367023809523809,\n", - " 11.281285714285717\n", - " ],\n", - " [\n", - " 7.187404761904762,\n", - " 11.730333333333336\n", - " ],\n", - " [\n", - " 7.052690476190477,\n", - " 12.224285714285717\n", - " ],\n", - " [\n", - " 7.052690476190477,\n", - " 12.673333333333336\n", - " ],\n", - " [\n", - " 7.097595238095238,\n", - " 12.942761904761909\n", - " ],\n", - " [\n", - " 7.187404761904762,\n", - " 13.302000000000003\n", - " ],\n", - " [\n", - " 7.2323095238095245,\n", - " 13.706142857142861\n", - " ],\n", - " [\n", - " 7.097595238095238,\n", - " 14.065380952380956\n", - " ],\n", - " [\n", - " 6.917976190476191,\n", - " 14.42461904761905\n", - " ],\n", - " [\n", - " 6.558738095238095,\n", - " 14.694047619047623\n", - " ],\n", - " [\n", - " 6.019880952380953,\n", - " 14.918571428571433\n", - " ],\n", - " [\n", - " 5.660642857142857,\n", - " 15.143095238095242\n", - " ],\n", - " [\n", - " 5.256500000000001,\n", - " 15.367619047619051\n", - " ],\n", - " [\n", - " 5.076880952380952,\n", - " 15.59214285714286\n", - " ],\n", - " [\n", - " 4.852357142857143,\n", - " 15.81666666666667\n", - " ],\n", - " [\n", - " 4.582928571428571,\n", - " 
16.22080952380953\n", - " ],\n", - " [\n", - " 4.3134999999999994,\n", - " 16.580047619047622\n", - " ],\n", - " [\n", - " 3.954261904761905,\n", - " 16.75966666666667\n", - " ],\n", - " [\n", - " 3.550119047619047,\n", - " 16.939285714285717\n", - " ],\n", - " [\n", - " 3.1459761904761905,\n", - " 17.02909523809524\n", - " ],\n", - " [\n", - " 2.741833333333333,\n", - " 17.02909523809524\n", - " ],\n", - " [\n", - " 2.2927857142857135,\n", - " 17.074000000000005\n", - " ],\n", - " [\n", - " 1.8437380952380948,\n", - " 16.98419047619048\n", - " ]\n", - " ],\n", - " [\n", - " [\n", - " 18.278880952380952,\n", - " 14.514428571428574\n", - " ],\n", - " [\n", - " 17.91964285714286,\n", - " 14.334809523809525\n", - " ],\n", - " [\n", - " 17.515500000000003,\n", - " 14.110285714285716\n", - " ],\n", - " [\n", - " 17.290976190476194,\n", - " 13.795952380952386\n", - " ],\n", - " [\n", - " 17.156261904761905,\n", - " 13.571428571428573\n", - " ],\n", - " [\n", - " 16.931738095238096,\n", - " 13.122380952380954\n", - " ],\n", - " [\n", - " 16.662309523809526,\n", - " 12.583523809523811\n", - " ],\n", - " [\n", - " 16.572500000000005,\n", - " 12.26919047619048\n", - " ],\n", - " [\n", - " 16.527595238095238,\n", - " 11.909952380952383\n", - " ],\n", - " [\n", - " 16.437785714285717,\n", - " 11.550714285714289\n", - " ],\n", - " [\n", - " 16.437785714285717,\n", - " 11.191476190476193\n", - " ],\n", - " [\n", - " 16.303071428571428,\n", - " 10.832238095238099\n", - " ],\n", - " [\n", - " 15.854023809523811,\n", - " 10.562809523809527\n", - " ],\n", - " [\n", - " 15.539690476190477,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 14.955928571428574,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 14.41707142857143,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 13.833309523809524,\n", - " 10.338285714285718\n", - " ],\n", - " [\n", - " 13.339357142857144,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 12.800500000000001,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 12.306547619047619,\n", - " 10.383190476190478\n", - " ],\n", - " [\n", - " 11.8575,\n", - " 10.383190476190478\n", - " ],\n", - " [\n", - " 11.453357142857143,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 10.9145,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 10.510357142857144,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 10.151119047619048,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 9.746976190476191,\n", - " 10.42809523809524\n", - " ],\n", - " [\n", - " 9.208119047619048,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 8.53454761904762,\n", - " 10.473000000000003\n", - " ],\n", - " [\n", - " 8.130404761904762,\n", - " 10.742428571428574\n", - " ],\n", - " [\n", - " 7.636452380952381,\n", - " 11.371095238095242\n", - " ],\n", - " [\n", - " 7.411928571428572,\n", - " 11.865047619047623\n", - " ],\n", - " [\n", - " 7.187404761904762,\n", - " 12.583523809523811\n", - " ],\n", - " [\n", - " 7.1425,\n", - " 13.167285714285718\n", - " ],\n", - " [\n", - " 7.097595238095238,\n", - " 13.706142857142861\n", - " ],\n", - " [\n", - " 7.052690476190477,\n", - " 14.20009523809524\n", - " ],\n", - " [\n", - " 7.007785714285715,\n", - " 14.604238095238099\n", - " ],\n", - " [\n", - " 6.648547619047619,\n", - " 15.053285714285717\n", - " ]\n", - " ],\n", - " [\n", - " [\n", - " 14.821214285714289,\n", - " 14.604238095238099\n", - " ],\n", - " [\n", - " 14.461976190476191,\n", - " 14.42461904761905\n", - " ],\n", - " [\n", - " 14.147642857142857,\n", - " 
14.020476190476195\n", - " ],\n", - " [\n", - " 13.923119047619048,\n", - " 13.481619047619052\n", - " ],\n", - " [\n", - " 13.743500000000003,\n", - " 12.80804761904762\n", - " ],\n", - " [\n", - " 13.608785714285714,\n", - " 12.179380952380955\n", - " ],\n", - " [\n", - " 13.518976190476193,\n", - " 11.640523809523813\n", - " ],\n", - " [\n", - " 13.47407142857143,\n", - " 11.191476190476193\n", - " ],\n", - " [\n", - " 13.384261904761905,\n", - " 10.832238095238099\n", - " ],\n", - " [\n", - " 13.02502380952381,\n", - " 10.562809523809527\n", - " ],\n", - " [\n", - " 12.261642857142858,\n", - " 10.517904761904765\n", - " ],\n", - " [\n", - " 11.8575,\n", - " 11.056761904761908\n", - " ],\n", - " [\n", - " 11.767690476190479,\n", - " 11.82014285714286\n", - " ],\n", - " [\n", - " 11.722785714285715,\n", - " 12.403904761904766\n", - " ],\n", - " [\n", - " 11.58807142857143,\n", - " 13.122380952380954\n", - " ],\n", - " [\n", - " 11.408452380952381,\n", - " 13.751047619047622\n", - " ],\n", - " [\n", - " 11.09411904761905,\n", - " 14.334809523809525\n", - " ],\n", - " [\n", - " 10.555261904761906,\n", - " 14.64914285714286\n", - " ]\n", - " ],\n", - " []\n", - " ],\n", - " \"verts\": [\n", - " [\n", - " 13.24,\n", - " 0.1\n", - " ],\n", - " [\n", - " 13.23,\n", - " 0.11\n", - " ],\n", - " [\n", - " 13.23,\n", - " 5.67\n", - " ],\n", - " [\n", - " 13.31,\n", - " 5.67\n", - " ],\n", - " [\n", - " 13.32,\n", - " 5.68\n", - " ],\n", - " [\n", - " 13.32,\n", - " 5.78\n", - " ],\n", - " [\n", - " 13.31,\n", - " 5.79\n", - " ],\n", - " [\n", - " 12.24,\n", - " 5.79\n", - " ],\n", - " [\n", - " 12.23,\n", - " 5.8\n", - " ],\n", - " [\n", - " 12.23,\n", - " 9.57\n", - " ],\n", - " [\n", - " 14.52,\n", - " 9.57\n", - " ],\n", - " [\n", - " 14.53,\n", - " 9.58\n", - " ],\n", - " [\n", - " 14.53,\n", - " 9.68\n", - " ],\n", - " [\n", - " 14.52,\n", - " 9.69\n", - " ],\n", - " [\n", - " 8.93,\n", - " 9.69\n", - " ],\n", - " [\n", - " 8.93,\n", - " 9.85\n", - " ],\n", - " [\n", - " 8.92,\n", - " 9.86\n", - " ],\n", - " [\n", - " 8.82,\n", - " 9.86\n", - " ],\n", - " [\n", - " 8.81,\n", - " 9.85\n", - " ],\n", - " [\n", - " 8.81,\n", - " 9.69\n", - " ],\n", - " [\n", - " 5.98,\n", - " 9.69\n", - " ],\n", - " [\n", - " 5.97,\n", - " 9.7\n", - " ],\n", - " [\n", - " 5.96,\n", - " 9.7\n", - " ],\n", - " [\n", - " 5.95,\n", - " 9.69\n", - " ],\n", - " [\n", - " 5.63,\n", - " 9.69\n", - " ],\n", - " [\n", - " 5.62,\n", - " 9.68\n", - " ],\n", - " [\n", - " 5.62,\n", - " 9.58\n", - " ],\n", - " [\n", - " 5.63,\n", - " 9.57\n", - " ],\n", - " [\n", - " 5.87,\n", - " 9.57\n", - " ],\n", - " [\n", - " 5.87,\n", - " 5.79\n", - " ],\n", - " [\n", - " 1.94,\n", - " 5.79\n", - " ],\n", - " [\n", - " 1.93,\n", - " 5.8\n", - " ],\n", - " [\n", - " 1.93,\n", - " 9.07\n", - " ],\n", - " [\n", - " 4.51,\n", - " 9.07\n", - " ],\n", - " [\n", - " 4.53,\n", - " 9.09\n", - " ],\n", - " [\n", - " 4.53,\n", - " 9.57\n", - " ],\n", - " [\n", - " 4.77,\n", - " 9.57\n", - " ],\n", - " [\n", - " 4.78,\n", - " 9.58\n", - " ],\n", - " [\n", - " 4.78,\n", - " 9.68\n", - " ],\n", - " [\n", - " 4.77,\n", - " 9.69\n", - " ],\n", - " [\n", - " 4.53,\n", - " 9.69\n", - " ],\n", - " [\n", - " 4.53,\n", - " 10.14\n", - " ],\n", - " [\n", - " 4.52,\n", - " 10.15\n", - " ],\n", - " [\n", - " 4.42,\n", - " 10.15\n", - " ],\n", - " [\n", - " 4.41,\n", - " 10.14\n", - " ],\n", - " [\n", - " 4.41,\n", - " 9.19\n", - " ],\n", - " [\n", - " 1.93,\n", - " 9.19\n", - " ],\n", - " [\n", - " 1.93,\n", - " 11.07\n", - " ],\n", - " [\n", - " 4.41,\n", - " 
11.07\n", - " ],\n", - " [\n", - " 4.41,\n", - " 11.0\n", - " ],\n", - " [\n", - " 4.42,\n", - " 10.99\n", - " ],\n", - " [\n", - " 4.52,\n", - " 10.99\n", - " ],\n", - " [\n", - " 4.53,\n", - " 11.0\n", - " ],\n", - " [\n", - " 4.53,\n", - " 11.07\n", - " ],\n", - " [\n", - " 5.32,\n", - " 11.07\n", - " ],\n", - " [\n", - " 5.33,\n", - " 11.06\n", - " ],\n", - " [\n", - " 5.33,\n", - " 11.05\n", - " ],\n", - " [\n", - " 5.34,\n", - " 11.04\n", - " ],\n", - " [\n", - " 5.36,\n", - " 11.04\n", - " ],\n", - " [\n", - " 5.37,\n", - " 11.05\n", - " ],\n", - " [\n", - " 5.38,\n", - " 11.05\n", - " ],\n", - " [\n", - " 5.39,\n", - " 11.06\n", - " ],\n", - " [\n", - " 5.4,\n", - " 11.06\n", - " ],\n", - " [\n", - " 5.41,\n", - " 11.07\n", - " ],\n", - " [\n", - " 5.42,\n", - " 11.07\n", - " ],\n", - " [\n", - " 5.43,\n", - " 11.08\n", - " ],\n", - " [\n", - " 5.44,\n", - " 11.08\n", - " ],\n", - " [\n", - " 5.45,\n", - " 11.09\n", - " ],\n", - " [\n", - " 5.43,\n", - " 11.11\n", - " ],\n", - " [\n", - " 5.43,\n", - " 11.18\n", - " ],\n", - " [\n", - " 5.42,\n", - " 11.19\n", - " ],\n", - " [\n", - " 5.42,\n", - " 13.57\n", - " ],\n", - " [\n", - " 6.46,\n", - " 13.57\n", - " ],\n", - " [\n", - " 6.47,\n", - " 13.58\n", - " ],\n", - " [\n", - " 6.47,\n", - " 13.68\n", - " ],\n", - " [\n", - " 6.46,\n", - " 13.69\n", - " ],\n", - " [\n", - " 4.19,\n", - " 13.69\n", - " ],\n", - " [\n", - " 4.18,\n", - " 13.68\n", - " ],\n", - " [\n", - " 4.18,\n", - " 13.58\n", - " ],\n", - " [\n", - " 4.19,\n", - " 13.57\n", - " ],\n", - " [\n", - " 5.31,\n", - " 13.57\n", - " ],\n", - " [\n", - " 5.31,\n", - " 11.19\n", - " ],\n", - " [\n", - " 1.93,\n", - " 11.19\n", - " ],\n", - " [\n", - " 1.93,\n", - " 13.57\n", - " ],\n", - " [\n", - " 3.23,\n", - " 13.57\n", - " ],\n", - " [\n", - " 3.24,\n", - " 13.58\n", - " ],\n", - " [\n", - " 3.24,\n", - " 13.68\n", - " ],\n", - " [\n", - " 3.23,\n", - " 13.69\n", - " ],\n", - " [\n", - " 1.89,\n", - " 13.69\n", - " ],\n", - " [\n", - " 1.88,\n", - " 13.7\n", - " ],\n", - " [\n", - " 1.87,\n", - " 13.7\n", - " ],\n", - " [\n", - " 1.85,\n", - " 13.72\n", - " ],\n", - " [\n", - " 1.84,\n", - " 13.72\n", - " ],\n", - " [\n", - " 1.82,\n", - " 13.74\n", - " ],\n", - " [\n", - " 1.81,\n", - " 13.74\n", - " ],\n", - " [\n", - " 1.79,\n", - " 13.76\n", - " ],\n", - " [\n", - " 1.78,\n", - " 13.76\n", - " ],\n", - " [\n", - " 1.76,\n", - " 13.78\n", - " ],\n", - " [\n", - " 1.75,\n", - " 13.78\n", - " ],\n", - " [\n", - " 1.73,\n", - " 13.8\n", - " ],\n", - " [\n", - " 1.72,\n", - " 13.8\n", - " ],\n", - " [\n", - " 1.71,\n", - " 13.81\n", - " ],\n", - " [\n", - " 1.7,\n", - " 13.81\n", - " ],\n", - " [\n", - " 1.68,\n", - " 13.83\n", - " ],\n", - " [\n", - " 1.67,\n", - " 13.83\n", - " ],\n", - " [\n", - " 1.65,\n", - " 13.85\n", - " ],\n", - " [\n", - " 1.64,\n", - " 13.85\n", - " ],\n", - " [\n", - " 1.62,\n", - " 13.87\n", - " ],\n", - " [\n", - " 1.61,\n", - " 13.87\n", - " ],\n", - " [\n", - " 1.59,\n", - " 13.89\n", - " ],\n", - " [\n", - " 1.58,\n", - " 13.89\n", - " ],\n", - " [\n", - " 1.56,\n", - " 13.91\n", - " ],\n", - " [\n", - " 1.55,\n", - " 13.91\n", - " ],\n", - " [\n", - " 1.53,\n", - " 13.93\n", - " ],\n", - " [\n", - " 1.52,\n", - " 13.93\n", - " ],\n", - " [\n", - " 1.5,\n", - " 13.95\n", - " ],\n", - " [\n", - " 1.49,\n", - " 13.95\n", - " ],\n", - " [\n", - " 1.47,\n", - " 13.97\n", - " ],\n", - " [\n", - " 1.46,\n", - " 13.97\n", - " ],\n", - " [\n", - " 1.44,\n", - " 13.99\n", - " ],\n", - " [\n", - " 1.43,\n", - " 13.99\n", - " ],\n", - " [\n", - " 
1.41,\n", - " 14.01\n", - " ],\n", - " [\n", - " 1.4,\n", - " 14.01\n", - " ],\n", - " [\n", - " 1.38,\n", - " 14.03\n", - " ],\n", - " [\n", - " 1.37,\n", - " 14.03\n", - " ],\n", - " [\n", - " 1.35,\n", - " 14.05\n", - " ],\n", - " [\n", - " 1.34,\n", - " 14.05\n", - " ],\n", - " [\n", - " 1.32,\n", - " 14.07\n", - " ],\n", - " [\n", - " 1.31,\n", - " 14.07\n", - " ],\n", - " [\n", - " 1.29,\n", - " 14.09\n", - " ],\n", - " [\n", - " 1.28,\n", - " 14.09\n", - " ],\n", - " [\n", - " 1.27,\n", - " 14.1\n", - " ],\n", - " [\n", - " 1.26,\n", - " 14.1\n", - " ],\n", - " [\n", - " 1.24,\n", - " 14.12\n", - " ],\n", - " [\n", - " 1.23,\n", - " 14.12\n", - " ],\n", - " [\n", - " 1.21,\n", - " 14.14\n", - " ],\n", - " [\n", - " 1.2,\n", - " 14.14\n", - " ],\n", - " [\n", - " 1.18,\n", - " 14.16\n", - " ],\n", - " [\n", - " 1.17,\n", - " 14.16\n", - " ],\n", - " [\n", - " 1.15,\n", - " 14.18\n", - " ],\n", - " [\n", - " 1.14,\n", - " 14.18\n", - " ],\n", - " [\n", - " 1.12,\n", - " 14.2\n", - " ],\n", - " [\n", - " 1.11,\n", - " 14.2\n", - " ],\n", - " [\n", - " 1.09,\n", - " 14.22\n", - " ],\n", - " [\n", - " 1.08,\n", - " 14.22\n", - " ],\n", - " [\n", - " 1.06,\n", - " 14.24\n", - " ],\n", - " [\n", - " 1.05,\n", - " 14.24\n", - " ],\n", - " [\n", - " 1.03,\n", - " 14.26\n", - " ],\n", - " [\n", - " 1.02,\n", - " 14.26\n", - " ],\n", - " [\n", - " 1.0,\n", - " 14.28\n", - " ],\n", - " [\n", - " 0.99,\n", - " 14.28\n", - " ],\n", - " [\n", - " 0.97,\n", - " 14.3\n", - " ],\n", - " [\n", - " 0.96,\n", - " 14.3\n", - " ],\n", - " [\n", - " 0.94,\n", - " 14.32\n", - " ],\n", - " [\n", - " 0.93,\n", - " 14.32\n", - " ],\n", - " [\n", - " 0.91,\n", - " 14.34\n", - " ],\n", - " [\n", - " 0.9,\n", - " 14.34\n", - " ],\n", - " [\n", - " 0.88,\n", - " 14.36\n", - " ],\n", - " [\n", - " 0.87,\n", - " 14.36\n", - " ],\n", - " [\n", - " 0.85,\n", - " 14.38\n", - " ],\n", - " [\n", - " 0.84,\n", - " 14.38\n", - " ],\n", - " [\n", - " 0.83,\n", - " 14.39\n", - " ],\n", - " [\n", - " 0.82,\n", - " 14.39\n", - " ],\n", - " [\n", - " 0.8,\n", - " 14.41\n", - " ],\n", - " [\n", - " 0.79,\n", - " 14.41\n", - " ],\n", - " [\n", - " 0.77,\n", - " 14.43\n", - " ],\n", - " [\n", - " 0.76,\n", - " 14.43\n", - " ],\n", - " [\n", - " 0.74,\n", - " 14.45\n", - " ],\n", - " [\n", - " 0.73,\n", - " 14.45\n", - " ],\n", - " [\n", - " 0.71,\n", - " 14.47\n", - " ],\n", - " [\n", - " 0.7,\n", - " 14.47\n", - " ],\n", - " [\n", - " 0.68,\n", - " 14.49\n", - " ],\n", - " [\n", - " 0.67,\n", - " 14.49\n", - " ],\n", - " [\n", - " 0.65,\n", - " 14.51\n", - " ],\n", - " [\n", - " 0.64,\n", - " 14.51\n", - " ],\n", - " [\n", - " 0.62,\n", - " 14.53\n", - " ],\n", - " [\n", - " 0.61,\n", - " 14.53\n", - " ],\n", - " [\n", - " 0.59,\n", - " 14.55\n", - " ],\n", - " [\n", - " 0.58,\n", - " 14.55\n", - " ],\n", - " [\n", - " 0.56,\n", - " 14.57\n", - " ],\n", - " [\n", - " 0.55,\n", - " 14.57\n", - " ],\n", - " [\n", - " 0.53,\n", - " 14.59\n", - " ],\n", - " [\n", - " 0.52,\n", - " 14.59\n", - " ],\n", - " [\n", - " 0.5,\n", - " 14.61\n", - " ],\n", - " [\n", - " 0.49,\n", - " 14.61\n", - " ],\n", - " [\n", - " 0.47,\n", - " 14.63\n", - " ],\n", - " [\n", - " 0.46,\n", - " 14.63\n", - " ],\n", - " [\n", - " 0.44,\n", - " 14.65\n", - " ],\n", - " [\n", - " 0.43,\n", - " 14.65\n", - " ],\n", - " [\n", - " 0.41,\n", - " 14.67\n", - " ],\n", - " [\n", - " 0.4,\n", - " 14.67\n", - " ],\n", - " [\n", - " 0.39,\n", - " 14.68\n", - " ],\n", - " [\n", - " 0.38,\n", - " 14.68\n", - " ],\n", - " [\n", - " 0.36,\n", - " 14.7\n", - " ],\n", - " 
[\n", - " 0.35,\n", - " 14.7\n", - " ],\n", - " [\n", - " 0.33,\n", - " 14.72\n", - " ],\n", - " [\n", - " 0.32,\n", - " 14.72\n", - " ],\n", - " [\n", - " 0.3,\n", - " 14.74\n", - " ],\n", - " [\n", - " 0.29,\n", - " 14.74\n", - " ],\n", - " [\n", - " 0.27,\n", - " 14.76\n", - " ],\n", - " [\n", - " 0.26,\n", - " 14.76\n", - " ],\n", - " [\n", - " 0.24,\n", - " 14.78\n", - " ],\n", - " [\n", - " 0.23,\n", - " 14.78\n", - " ],\n", - " [\n", - " 0.21,\n", - " 14.8\n", - " ],\n", - " [\n", - " 0.2,\n", - " 14.8\n", - " ],\n", - " [\n", - " 0.18,\n", - " 14.82\n", - " ],\n", - " [\n", - " 0.17,\n", - " 14.82\n", - " ],\n", - " [\n", - " 0.15,\n", - " 14.84\n", - " ],\n", - " [\n", - " 0.14,\n", - " 14.84\n", - " ],\n", - " [\n", - " 0.12,\n", - " 14.86\n", - " ],\n", - " [\n", - " 0.11,\n", - " 14.86\n", - " ],\n", - " [\n", - " 0.1,\n", - " 14.87\n", - " ],\n", - " [\n", - " 0.1,\n", - " 17.78\n", - " ],\n", - " [\n", - " 0.11,\n", - " 17.79\n", - " ],\n", - " [\n", - " 0.12,\n", - " 17.79\n", - " ],\n", - " [\n", - " 0.14,\n", - " 17.81\n", - " ],\n", - " [\n", - " 0.15,\n", - " 17.81\n", - " ],\n", - " [\n", - " 0.17,\n", - " 17.83\n", - " ],\n", - " [\n", - " 0.18,\n", - " 17.83\n", - " ],\n", - " [\n", - " 0.2,\n", - " 17.85\n", - " ],\n", - " [\n", - " 0.21,\n", - " 17.85\n", - " ],\n", - " [\n", - " 0.23,\n", - " 17.87\n", - " ],\n", - " [\n", - " 0.24,\n", - " 17.87\n", - " ],\n", - " [\n", - " 0.26,\n", - " 17.89\n", - " ],\n", - " [\n", - " 0.27,\n", - " 17.89\n", - " ],\n", - " [\n", - " 0.29,\n", - " 17.91\n", - " ],\n", - " [\n", - " 0.3,\n", - " 17.91\n", - " ],\n", - " [\n", - " 0.32,\n", - " 17.93\n", - " ],\n", - " [\n", - " 0.33,\n", - " 17.93\n", - " ],\n", - " [\n", - " 0.35,\n", - " 17.95\n", - " ],\n", - " [\n", - " 0.36,\n", - " 17.95\n", - " ],\n", - " [\n", - " 0.38,\n", - " 17.97\n", - " ],\n", - " [\n", - " 0.39,\n", - " 17.97\n", - " ],\n", - " [\n", - " 0.41,\n", - " 17.99\n", - " ],\n", - " [\n", - " 0.42,\n", - " 17.99\n", - " ],\n", - " [\n", - " 0.44,\n", - " 18.01\n", - " ],\n", - " [\n", - " 0.45,\n", - " 18.01\n", - " ],\n", - " [\n", - " 0.47,\n", - " 18.03\n", - " ],\n", - " [\n", - " 0.48,\n", - " 18.03\n", - " ],\n", - " [\n", - " 0.5,\n", - " 18.05\n", - " ],\n", - " [\n", - " 0.51,\n", - " 18.05\n", - " ],\n", - " [\n", - " 0.53,\n", - " 18.07\n", - " ],\n", - " [\n", - " 0.54,\n", - " 18.07\n", - " ],\n", - " [\n", - " 0.56,\n", - " 18.09\n", - " ],\n", - " [\n", - " 0.57,\n", - " 18.09\n", - " ],\n", - " [\n", - " 0.59,\n", - " 18.11\n", - " ],\n", - " [\n", - " 0.6,\n", - " 18.11\n", - " ],\n", - " [\n", - " 0.62,\n", - " 18.13\n", - " ],\n", - " [\n", - " 0.63,\n", - " 18.13\n", - " ],\n", - " [\n", - " 0.65,\n", - " 18.15\n", - " ],\n", - " [\n", - " 0.66,\n", - " 18.15\n", - " ],\n", - " [\n", - " 0.68,\n", - " 18.17\n", - " ],\n", - " [\n", - " 0.69,\n", - " 18.17\n", - " ],\n", - " [\n", - " 0.71,\n", - " 18.19\n", - " ],\n", - " [\n", - " 0.72,\n", - " 18.19\n", - " ],\n", - " [\n", - " 0.74,\n", - " 18.21\n", - " ],\n", - " [\n", - " 0.75,\n", - " 18.21\n", - " ],\n", - " [\n", - " 0.77,\n", - " 18.23\n", - " ],\n", - " [\n", - " 0.78,\n", - " 18.23\n", - " ],\n", - " [\n", - " 0.8,\n", - " 18.25\n", - " ],\n", - " [\n", - " 0.81,\n", - " 18.25\n", - " ],\n", - " [\n", - " 0.83,\n", - " 18.27\n", - " ],\n", - " [\n", - " 0.84,\n", - " 18.27\n", - " ],\n", - " [\n", - " 0.86,\n", - " 18.29\n", - " ],\n", - " [\n", - " 0.87,\n", - " 18.29\n", - " ],\n", - " [\n", - " 0.89,\n", - " 18.31\n", - " ],\n", - " [\n", - " 0.9,\n", - " 18.31\n", - " 
],\n", - " [\n", - " 0.91,\n", - " 18.32\n", - " ],\n", - " [\n", - " 0.92,\n", - " 18.32\n", - " ],\n", - " [\n", - " 0.94,\n", - " 18.34\n", - " ],\n", - " [\n", - " 0.95,\n", - " 18.34\n", - " ],\n", - " [\n", - " 0.97,\n", - " 18.36\n", - " ],\n", - " [\n", - " 0.98,\n", - " 18.36\n", - " ],\n", - " [\n", - " 1.0,\n", - " 18.38\n", - " ],\n", - " [\n", - " 1.01,\n", - " 18.38\n", - " ],\n", - " [\n", - " 1.03,\n", - " 18.4\n", - " ],\n", - " [\n", - " 1.04,\n", - " 18.4\n", - " ],\n", - " [\n", - " 1.06,\n", - " 18.42\n", - " ],\n", - " [\n", - " 1.07,\n", - " 18.42\n", - " ],\n", - " [\n", - " 1.09,\n", - " 18.44\n", - " ],\n", - " [\n", - " 1.1,\n", - " 18.44\n", - " ],\n", - " [\n", - " 1.12,\n", - " 18.46\n", - " ],\n", - " [\n", - " 1.13,\n", - " 18.46\n", - " ],\n", - " [\n", - " 1.15,\n", - " 18.48\n", - " ],\n", - " [\n", - " 1.16,\n", - " 18.48\n", - " ],\n", - " [\n", - " 1.18,\n", - " 18.5\n", - " ],\n", - " [\n", - " 1.19,\n", - " 18.5\n", - " ],\n", - " [\n", - " 1.21,\n", - " 18.52\n", - " ],\n", - " [\n", - " 1.22,\n", - " 18.52\n", - " ],\n", - " [\n", - " 1.24,\n", - " 18.54\n", - " ],\n", - " [\n", - " 1.25,\n", - " 18.54\n", - " ],\n", - " [\n", - " 1.27,\n", - " 18.56\n", - " ],\n", - " [\n", - " 1.28,\n", - " 18.56\n", - " ],\n", - " [\n", - " 1.3,\n", - " 18.58\n", - " ],\n", - " [\n", - " 1.31,\n", - " 18.58\n", - " ],\n", - " [\n", - " 1.33,\n", - " 18.6\n", - " ],\n", - " [\n", - " 1.34,\n", - " 18.6\n", - " ],\n", - " [\n", - " 1.36,\n", - " 18.62\n", - " ],\n", - " [\n", - " 1.37,\n", - " 18.62\n", - " ],\n", - " [\n", - " 1.39,\n", - " 18.64\n", - " ],\n", - " [\n", - " 1.4,\n", - " 18.64\n", - " ],\n", - " [\n", - " 1.42,\n", - " 18.66\n", - " ],\n", - " [\n", - " 1.43,\n", - " 18.66\n", - " ],\n", - " [\n", - " 1.45,\n", - " 18.68\n", - " ],\n", - " [\n", - " 1.46,\n", - " 18.68\n", - " ],\n", - " [\n", - " 1.48,\n", - " 18.7\n", - " ],\n", - " [\n", - " 1.49,\n", - " 18.7\n", - " ],\n", - " [\n", - " 1.51,\n", - " 18.72\n", - " ],\n", - " [\n", - " 1.52,\n", - " 18.72\n", - " ],\n", - " [\n", - " 1.54,\n", - " 18.74\n", - " ],\n", - " [\n", - " 1.55,\n", - " 18.74\n", - " ],\n", - " [\n", - " 1.57,\n", - " 18.76\n", - " ],\n", - " [\n", - " 1.58,\n", - " 18.76\n", - " ],\n", - " [\n", - " 1.6,\n", - " 18.78\n", - " ],\n", - " [\n", - " 1.61,\n", - " 18.78\n", - " ],\n", - " [\n", - " 1.63,\n", - " 18.8\n", - " ],\n", - " [\n", - " 1.64,\n", - " 18.8\n", - " ],\n", - " [\n", - " 1.66,\n", - " 18.82\n", - " ],\n", - " [\n", - " 1.67,\n", - " 18.82\n", - " ],\n", - " [\n", - " 1.69,\n", - " 18.84\n", - " ],\n", - " [\n", - " 1.7,\n", - " 18.84\n", - " ],\n", - " [\n", - " 1.72,\n", - " 18.86\n", - " ],\n", - " [\n", - " 1.73,\n", - " 18.86\n", - " ],\n", - " [\n", - " 1.75,\n", - " 18.88\n", - " ],\n", - " [\n", - " 1.76,\n", - " 18.88\n", - " ],\n", - " [\n", - " 1.78,\n", - " 18.9\n", - " ],\n", - " [\n", - " 1.79,\n", - " 18.9\n", - " ],\n", - " [\n", - " 1.81,\n", - " 18.92\n", - " ],\n", - " [\n", - " 1.82,\n", - " 18.92\n", - " ],\n", - " [\n", - " 1.84,\n", - " 18.94\n", - " ],\n", - " [\n", - " 1.85,\n", - " 18.94\n", - " ],\n", - " [\n", - " 1.87,\n", - " 18.96\n", - " ],\n", - " [\n", - " 8.8,\n", - " 18.96\n", - " ],\n", - " [\n", - " 8.8,\n", - " 15.38\n", - " ],\n", - " [\n", - " 8.81,\n", - " 15.37\n", - " ],\n", - " [\n", - " 8.81,\n", - " 13.69\n", - " ],\n", - " [\n", - " 8.32,\n", - " 13.69\n", - " ],\n", - " [\n", - " 8.31,\n", - " 13.68\n", - " ],\n", - " [\n", - " 8.31,\n", - " 13.58\n", - " ],\n", - " [\n", - " 8.32,\n", - " 13.57\n", - 
" ],\n", - " [\n", - " 8.81,\n", - " 13.57\n", - " ],\n", - " [\n", - " 8.81,\n", - " 10.87\n", - " ],\n", - " [\n", - " 8.82,\n", - " 10.86\n", - " ],\n", - " [\n", - " 8.92,\n", - " 10.86\n", - " ],\n", - " [\n", - " 8.93,\n", - " 10.87\n", - " ],\n", - " [\n", - " 8.93,\n", - " 11.07\n", - " ],\n", - " [\n", - " 11.61,\n", - " 11.07\n", - " ],\n", - " [\n", - " 11.62,\n", - " 11.08\n", - " ],\n", - " [\n", - " 11.62,\n", - " 11.18\n", - " ],\n", - " [\n", - " 11.61,\n", - " 11.19\n", - " ],\n", - " [\n", - " 8.93,\n", - " 11.19\n", - " ],\n", - " [\n", - " 8.93,\n", - " 15.26\n", - " ],\n", - " [\n", - " 12.71,\n", - " 15.26\n", - " ],\n", - " [\n", - " 12.71,\n", - " 11.19\n", - " ],\n", - " [\n", - " 12.47,\n", - " 11.19\n", - " ],\n", - " [\n", - " 12.46,\n", - " 11.18\n", - " ],\n", - " [\n", - " 12.46,\n", - " 11.08\n", - " ],\n", - " [\n", - " 12.47,\n", - " 11.07\n", - " ],\n", - " [\n", - " 13.04,\n", - " 11.07\n", - " ],\n", - " [\n", - " 13.05,\n", - " 11.08\n", - " ],\n", - " [\n", - " 13.05,\n", - " 11.18\n", - " ],\n", - " [\n", - " 13.04,\n", - " 11.19\n", - " ],\n", - " [\n", - " 12.83,\n", - " 11.19\n", - " ],\n", - " [\n", - " 12.83,\n", - " 15.26\n", - " ],\n", - " [\n", - " 15.81,\n", - " 15.26\n", - " ],\n", - " [\n", - " 15.81,\n", - " 11.19\n", - " ],\n", - " [\n", - " 13.9,\n", - " 11.19\n", - " ],\n", - " [\n", - " 13.89,\n", - " 11.18\n", - " ],\n", - " [\n", - " 13.89,\n", - " 11.08\n", - " ],\n", - " [\n", - " 13.9,\n", - " 11.07\n", - " ],\n", - " [\n", - " 15.99,\n", - " 11.07\n", - " ],\n", - " [\n", - " 16.0,\n", - " 11.08\n", - " ],\n", - " [\n", - " 16.0,\n", - " 11.18\n", - " ],\n", - " [\n", - " 15.99,\n", - " 11.19\n", - " ],\n", - " [\n", - " 15.93,\n", - " 11.19\n", - " ],\n", - " [\n", - " 15.93,\n", - " 15.26\n", - " ],\n", - " [\n", - " 18.9,\n", - " 15.26\n", - " ],\n", - " [\n", - " 18.9,\n", - " 11.19\n", - " ],\n", - " [\n", - " 16.85,\n", - " 11.19\n", - " ],\n", - " [\n", - " 16.84,\n", - " 11.18\n", - " ],\n", - " [\n", - " 16.84,\n", - " 11.08\n", - " ],\n", - " [\n", - " 16.85,\n", - " 11.07\n", - " ],\n", - " [\n", - " 17.01,\n", - " 11.07\n", - " ],\n", - " [\n", - " 17.01,\n", - " 11.03\n", - " ],\n", - " [\n", - " 17.02,\n", - " 11.02\n", - " ],\n", - " [\n", - " 17.12,\n", - " 11.02\n", - " ],\n", - " [\n", - " 17.13,\n", - " 11.03\n", - " ],\n", - " [\n", - " 17.13,\n", - " 11.07\n", - " ],\n", - " [\n", - " 18.9,\n", - " 11.07\n", - " ],\n", - " [\n", - " 18.9,\n", - " 9.69\n", - " ],\n", - " [\n", - " 17.13,\n", - " 9.69\n", - " ],\n", - " [\n", - " 17.13,\n", - " 10.17\n", - " ],\n", - " [\n", - " 17.12,\n", - " 10.18\n", - " ],\n", - " [\n", - " 17.02,\n", - " 10.18\n", - " ],\n", - " [\n", - " 17.01,\n", - " 10.17\n", - " ],\n", - " [\n", - " 17.01,\n", - " 9.69\n", - " ],\n", - " [\n", - " 16.85,\n", - " 9.69\n", - " ],\n", - " [\n", - " 16.84,\n", - " 9.68\n", - " ],\n", - " [\n", - " 16.84,\n", - " 9.58\n", - " ],\n", - " [\n", - " 16.85,\n", - " 9.57\n", - " ],\n", - " [\n", - " 18.9,\n", - " 9.57\n", - " ],\n", - " [\n", - " 18.9,\n", - " 5.79\n", - " ],\n", - " [\n", - " 15.73,\n", - " 5.79\n", - " ],\n", - " [\n", - " 15.73,\n", - " 9.57\n", - " ],\n", - " [\n", - " 15.99,\n", - " 9.57\n", - " ],\n", - " [\n", - " 16.0,\n", - " 9.58\n", - " ],\n", - " [\n", - " 16.0,\n", - " 9.68\n", - " ],\n", - " [\n", - " 15.99,\n", - " 9.69\n", - " ],\n", - " [\n", - " 15.38,\n", - " 9.69\n", - " ],\n", - " [\n", - " 15.37,\n", - " 9.68\n", - " ],\n", - " [\n", - " 15.37,\n", - " 9.58\n", - " ],\n", - " [\n", - " 15.38,\n", - " 
9.57\n", - " ],\n", - " [\n", - " 15.61,\n", - " 9.57\n", - " ],\n", - " [\n", - " 15.61,\n", - " 7.49\n", - " ],\n", - " [\n", - " 12.85,\n", - " 7.49\n", - " ],\n", - " [\n", - " 12.84,\n", - " 7.48\n", - " ],\n", - " [\n", - " 12.84,\n", - " 7.39\n", - " ],\n", - " [\n", - " 12.85,\n", - " 7.38\n", - " ],\n", - " [\n", - " 15.61,\n", - " 7.38\n", - " ],\n", - " [\n", - " 15.61,\n", - " 5.79\n", - " ],\n", - " [\n", - " 14.17,\n", - " 5.79\n", - " ],\n", - " [\n", - " 14.16,\n", - " 5.78\n", - " ],\n", - " [\n", - " 14.16,\n", - " 5.68\n", - " ],\n", - " [\n", - " 14.17,\n", - " 5.67\n", - " ],\n", - " [\n", - " 18.9,\n", - " 5.67\n", - " ],\n", - " [\n", - " 18.9,\n", - " 0.1\n", - " ],\n", + "\n", + "console = genjax.pretty(show_locals=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25421e64-5f56-4cbe-aa20-104e7400d8e6", + "metadata": {}, + "outputs": [], + "source": [ + "intrinsics = b.Intrinsics(\n", + " height=1, width=50, fx=10.0, fy=1.0, cx=25.0, cy=0.0, near=0.01, far=20.0\n", + ")\n", + "b.setup_renderer(intrinsics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00c9ee51-8744-470b-aeac-59ccb04d0ca6", + "metadata": {}, + "outputs": [], + "source": [ + "b.setup_visualizer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f959c4d7-f2c1-4449-8513-cdba7a4e036a", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "env_data = {\n", + " \"env_name\": \"014e6d1cee6e6a1297a78f761fbc6700.json\",\n", + " \"paths\": [\n", " [\n", - " 13.24,\n", - " 0.1\n", - " ]\n", + " [15.13554761904762, 14.694047619047623],\n", + " [14.821214285714289, 14.604238095238099],\n", + " [14.551785714285716, 14.469523809523814],\n", + " [14.372166666666667, 14.245000000000005],\n", + " [14.237452380952382, 13.975571428571431],\n", + " [14.102738095238097, 13.706142857142861],\n", + " [14.057833333333337, 13.346904761904764],\n", + " [14.012928571428573, 13.077476190476194],\n", + " [13.878214285714288, 12.80804761904762],\n", + " [13.788404761904763, 12.359000000000002],\n", + " [13.608785714285714, 11.999761904761908],\n", + " [13.518976190476193, 11.685428571428574],\n", + " [13.47407142857143, 11.371095238095242],\n", + " [13.384261904761905, 11.056761904761908],\n", + " [13.24954761904762, 10.787333333333336],\n", + " [12.890309523809526, 10.60771428571429],\n", + " [12.665785714285716, 10.60771428571429],\n", + " [12.261642857142858, 10.562809523809527],\n", + " [11.8575, 10.517904761904765],\n", + " [11.408452380952381, 10.473000000000003],\n", + " [10.9145, 10.473000000000003],\n", + " [10.42054761904762, 10.473000000000003],\n", + " [10.061309523809525, 10.42809523809524],\n", + " [9.567357142857144, 10.473000000000003],\n", + " [9.118309523809524, 10.473000000000003],\n", + " [8.714166666666667, 10.473000000000003],\n", + " [8.175309523809524, 10.60771428571429],\n", + " [7.726261904761905, 10.87714285714286],\n", + " [7.367023809523809, 11.281285714285717],\n", + " [7.187404761904762, 11.730333333333336],\n", + " [7.052690476190477, 12.224285714285717],\n", + " [7.052690476190477, 12.673333333333336],\n", + " [7.097595238095238, 12.942761904761909],\n", + " [7.187404761904762, 13.302000000000003],\n", + " [7.2323095238095245, 13.706142857142861],\n", + " [7.097595238095238, 14.065380952380956],\n", + " [6.917976190476191, 14.42461904761905],\n", + " [6.558738095238095, 14.694047619047623],\n", + " [6.019880952380953, 14.918571428571433],\n", + " [5.660642857142857, 
15.143095238095242],\n", + " [5.256500000000001, 15.367619047619051],\n", + " [5.076880952380952, 15.59214285714286],\n", + " [4.852357142857143, 15.81666666666667],\n", + " [4.582928571428571, 16.22080952380953],\n", + " [4.3134999999999994, 16.580047619047622],\n", + " [3.954261904761905, 16.75966666666667],\n", + " [3.550119047619047, 16.939285714285717],\n", + " [3.1459761904761905, 17.02909523809524],\n", + " [2.741833333333333, 17.02909523809524],\n", + " [2.2927857142857135, 17.074000000000005],\n", + " [1.8437380952380948, 16.98419047619048],\n", + " ],\n", + " [\n", + " [18.278880952380952, 14.514428571428574],\n", + " [17.91964285714286, 14.334809523809525],\n", + " [17.515500000000003, 14.110285714285716],\n", + " [17.290976190476194, 13.795952380952386],\n", + " [17.156261904761905, 13.571428571428573],\n", + " [16.931738095238096, 13.122380952380954],\n", + " [16.662309523809526, 12.583523809523811],\n", + " [16.572500000000005, 12.26919047619048],\n", + " [16.527595238095238, 11.909952380952383],\n", + " [16.437785714285717, 11.550714285714289],\n", + " [16.437785714285717, 11.191476190476193],\n", + " [16.303071428571428, 10.832238095238099],\n", + " [15.854023809523811, 10.562809523809527],\n", + " [15.539690476190477, 10.473000000000003],\n", + " [14.955928571428574, 10.42809523809524],\n", + " [14.41707142857143, 10.42809523809524],\n", + " [13.833309523809524, 10.338285714285718],\n", + " [13.339357142857144, 10.42809523809524],\n", + " [12.800500000000001, 10.473000000000003],\n", + " [12.306547619047619, 10.383190476190478],\n", + " [11.8575, 10.383190476190478],\n", + " [11.453357142857143, 10.42809523809524],\n", + " [10.9145, 10.42809523809524],\n", + " [10.510357142857144, 10.473000000000003],\n", + " [10.151119047619048, 10.42809523809524],\n", + " [9.746976190476191, 10.42809523809524],\n", + " [9.208119047619048, 10.473000000000003],\n", + " [8.53454761904762, 10.473000000000003],\n", + " [8.130404761904762, 10.742428571428574],\n", + " [7.636452380952381, 11.371095238095242],\n", + " [7.411928571428572, 11.865047619047623],\n", + " [7.187404761904762, 12.583523809523811],\n", + " [7.1425, 13.167285714285718],\n", + " [7.097595238095238, 13.706142857142861],\n", + " [7.052690476190477, 14.20009523809524],\n", + " [7.007785714285715, 14.604238095238099],\n", + " [6.648547619047619, 15.053285714285717],\n", + " ],\n", + " [\n", + " [14.821214285714289, 14.604238095238099],\n", + " [14.461976190476191, 14.42461904761905],\n", + " [14.147642857142857, 14.020476190476195],\n", + " [13.923119047619048, 13.481619047619052],\n", + " [13.743500000000003, 12.80804761904762],\n", + " [13.608785714285714, 12.179380952380955],\n", + " [13.518976190476193, 11.640523809523813],\n", + " [13.47407142857143, 11.191476190476193],\n", + " [13.384261904761905, 10.832238095238099],\n", + " [13.02502380952381, 10.562809523809527],\n", + " [12.261642857142858, 10.517904761904765],\n", + " [11.8575, 11.056761904761908],\n", + " [11.767690476190479, 11.82014285714286],\n", + " [11.722785714285715, 12.403904761904766],\n", + " [11.58807142857143, 13.122380952380954],\n", + " [11.408452380952381, 13.751047619047622],\n", + " [11.09411904761905, 14.334809523809525],\n", + " [10.555261904761906, 14.64914285714286],\n", + " ],\n", + " [],\n", + " ],\n", + " \"verts\": [\n", + " [13.24, 0.1],\n", + " [13.23, 0.11],\n", + " [13.23, 5.67],\n", + " [13.31, 5.67],\n", + " [13.32, 5.68],\n", + " [13.32, 5.78],\n", + " [13.31, 5.79],\n", + " [12.24, 5.79],\n", + " [12.23, 5.8],\n", + " [12.23, 
9.57],\n", + " [14.52, 9.57],\n", + " [14.53, 9.58],\n", + " [14.53, 9.68],\n", + " [14.52, 9.69],\n", + " [8.93, 9.69],\n", + " [8.93, 9.85],\n", + " [8.92, 9.86],\n", + " [8.82, 9.86],\n", + " [8.81, 9.85],\n", + " [8.81, 9.69],\n", + " [5.98, 9.69],\n", + " [5.97, 9.7],\n", + " [5.96, 9.7],\n", + " [5.95, 9.69],\n", + " [5.63, 9.69],\n", + " [5.62, 9.68],\n", + " [5.62, 9.58],\n", + " [5.63, 9.57],\n", + " [5.87, 9.57],\n", + " [5.87, 5.79],\n", + " [1.94, 5.79],\n", + " [1.93, 5.8],\n", + " [1.93, 9.07],\n", + " [4.51, 9.07],\n", + " [4.53, 9.09],\n", + " [4.53, 9.57],\n", + " [4.77, 9.57],\n", + " [4.78, 9.58],\n", + " [4.78, 9.68],\n", + " [4.77, 9.69],\n", + " [4.53, 9.69],\n", + " [4.53, 10.14],\n", + " [4.52, 10.15],\n", + " [4.42, 10.15],\n", + " [4.41, 10.14],\n", + " [4.41, 9.19],\n", + " [1.93, 9.19],\n", + " [1.93, 11.07],\n", + " [4.41, 11.07],\n", + " [4.41, 11.0],\n", + " [4.42, 10.99],\n", + " [4.52, 10.99],\n", + " [4.53, 11.0],\n", + " [4.53, 11.07],\n", + " [5.32, 11.07],\n", + " [5.33, 11.06],\n", + " [5.33, 11.05],\n", + " [5.34, 11.04],\n", + " [5.36, 11.04],\n", + " [5.37, 11.05],\n", + " [5.38, 11.05],\n", + " [5.39, 11.06],\n", + " [5.4, 11.06],\n", + " [5.41, 11.07],\n", + " [5.42, 11.07],\n", + " [5.43, 11.08],\n", + " [5.44, 11.08],\n", + " [5.45, 11.09],\n", + " [5.43, 11.11],\n", + " [5.43, 11.18],\n", + " [5.42, 11.19],\n", + " [5.42, 13.57],\n", + " [6.46, 13.57],\n", + " [6.47, 13.58],\n", + " [6.47, 13.68],\n", + " [6.46, 13.69],\n", + " [4.19, 13.69],\n", + " [4.18, 13.68],\n", + " [4.18, 13.58],\n", + " [4.19, 13.57],\n", + " [5.31, 13.57],\n", + " [5.31, 11.19],\n", + " [1.93, 11.19],\n", + " [1.93, 13.57],\n", + " [3.23, 13.57],\n", + " [3.24, 13.58],\n", + " [3.24, 13.68],\n", + " [3.23, 13.69],\n", + " [1.89, 13.69],\n", + " [1.88, 13.7],\n", + " [1.87, 13.7],\n", + " [1.85, 13.72],\n", + " [1.84, 13.72],\n", + " [1.82, 13.74],\n", + " [1.81, 13.74],\n", + " [1.79, 13.76],\n", + " [1.78, 13.76],\n", + " [1.76, 13.78],\n", + " [1.75, 13.78],\n", + " [1.73, 13.8],\n", + " [1.72, 13.8],\n", + " [1.71, 13.81],\n", + " [1.7, 13.81],\n", + " [1.68, 13.83],\n", + " [1.67, 13.83],\n", + " [1.65, 13.85],\n", + " [1.64, 13.85],\n", + " [1.62, 13.87],\n", + " [1.61, 13.87],\n", + " [1.59, 13.89],\n", + " [1.58, 13.89],\n", + " [1.56, 13.91],\n", + " [1.55, 13.91],\n", + " [1.53, 13.93],\n", + " [1.52, 13.93],\n", + " [1.5, 13.95],\n", + " [1.49, 13.95],\n", + " [1.47, 13.97],\n", + " [1.46, 13.97],\n", + " [1.44, 13.99],\n", + " [1.43, 13.99],\n", + " [1.41, 14.01],\n", + " [1.4, 14.01],\n", + " [1.38, 14.03],\n", + " [1.37, 14.03],\n", + " [1.35, 14.05],\n", + " [1.34, 14.05],\n", + " [1.32, 14.07],\n", + " [1.31, 14.07],\n", + " [1.29, 14.09],\n", + " [1.28, 14.09],\n", + " [1.27, 14.1],\n", + " [1.26, 14.1],\n", + " [1.24, 14.12],\n", + " [1.23, 14.12],\n", + " [1.21, 14.14],\n", + " [1.2, 14.14],\n", + " [1.18, 14.16],\n", + " [1.17, 14.16],\n", + " [1.15, 14.18],\n", + " [1.14, 14.18],\n", + " [1.12, 14.2],\n", + " [1.11, 14.2],\n", + " [1.09, 14.22],\n", + " [1.08, 14.22],\n", + " [1.06, 14.24],\n", + " [1.05, 14.24],\n", + " [1.03, 14.26],\n", + " [1.02, 14.26],\n", + " [1.0, 14.28],\n", + " [0.99, 14.28],\n", + " [0.97, 14.3],\n", + " [0.96, 14.3],\n", + " [0.94, 14.32],\n", + " [0.93, 14.32],\n", + " [0.91, 14.34],\n", + " [0.9, 14.34],\n", + " [0.88, 14.36],\n", + " [0.87, 14.36],\n", + " [0.85, 14.38],\n", + " [0.84, 14.38],\n", + " [0.83, 14.39],\n", + " [0.82, 14.39],\n", + " [0.8, 14.41],\n", + " [0.79, 14.41],\n", + " [0.77, 14.43],\n", + " 
[0.76, 14.43],\n", + " [0.74, 14.45],\n", + " [0.73, 14.45],\n", + " [0.71, 14.47],\n", + " [0.7, 14.47],\n", + " [0.68, 14.49],\n", + " [0.67, 14.49],\n", + " [0.65, 14.51],\n", + " [0.64, 14.51],\n", + " [0.62, 14.53],\n", + " [0.61, 14.53],\n", + " [0.59, 14.55],\n", + " [0.58, 14.55],\n", + " [0.56, 14.57],\n", + " [0.55, 14.57],\n", + " [0.53, 14.59],\n", + " [0.52, 14.59],\n", + " [0.5, 14.61],\n", + " [0.49, 14.61],\n", + " [0.47, 14.63],\n", + " [0.46, 14.63],\n", + " [0.44, 14.65],\n", + " [0.43, 14.65],\n", + " [0.41, 14.67],\n", + " [0.4, 14.67],\n", + " [0.39, 14.68],\n", + " [0.38, 14.68],\n", + " [0.36, 14.7],\n", + " [0.35, 14.7],\n", + " [0.33, 14.72],\n", + " [0.32, 14.72],\n", + " [0.3, 14.74],\n", + " [0.29, 14.74],\n", + " [0.27, 14.76],\n", + " [0.26, 14.76],\n", + " [0.24, 14.78],\n", + " [0.23, 14.78],\n", + " [0.21, 14.8],\n", + " [0.2, 14.8],\n", + " [0.18, 14.82],\n", + " [0.17, 14.82],\n", + " [0.15, 14.84],\n", + " [0.14, 14.84],\n", + " [0.12, 14.86],\n", + " [0.11, 14.86],\n", + " [0.1, 14.87],\n", + " [0.1, 17.78],\n", + " [0.11, 17.79],\n", + " [0.12, 17.79],\n", + " [0.14, 17.81],\n", + " [0.15, 17.81],\n", + " [0.17, 17.83],\n", + " [0.18, 17.83],\n", + " [0.2, 17.85],\n", + " [0.21, 17.85],\n", + " [0.23, 17.87],\n", + " [0.24, 17.87],\n", + " [0.26, 17.89],\n", + " [0.27, 17.89],\n", + " [0.29, 17.91],\n", + " [0.3, 17.91],\n", + " [0.32, 17.93],\n", + " [0.33, 17.93],\n", + " [0.35, 17.95],\n", + " [0.36, 17.95],\n", + " [0.38, 17.97],\n", + " [0.39, 17.97],\n", + " [0.41, 17.99],\n", + " [0.42, 17.99],\n", + " [0.44, 18.01],\n", + " [0.45, 18.01],\n", + " [0.47, 18.03],\n", + " [0.48, 18.03],\n", + " [0.5, 18.05],\n", + " [0.51, 18.05],\n", + " [0.53, 18.07],\n", + " [0.54, 18.07],\n", + " [0.56, 18.09],\n", + " [0.57, 18.09],\n", + " [0.59, 18.11],\n", + " [0.6, 18.11],\n", + " [0.62, 18.13],\n", + " [0.63, 18.13],\n", + " [0.65, 18.15],\n", + " [0.66, 18.15],\n", + " [0.68, 18.17],\n", + " [0.69, 18.17],\n", + " [0.71, 18.19],\n", + " [0.72, 18.19],\n", + " [0.74, 18.21],\n", + " [0.75, 18.21],\n", + " [0.77, 18.23],\n", + " [0.78, 18.23],\n", + " [0.8, 18.25],\n", + " [0.81, 18.25],\n", + " [0.83, 18.27],\n", + " [0.84, 18.27],\n", + " [0.86, 18.29],\n", + " [0.87, 18.29],\n", + " [0.89, 18.31],\n", + " [0.9, 18.31],\n", + " [0.91, 18.32],\n", + " [0.92, 18.32],\n", + " [0.94, 18.34],\n", + " [0.95, 18.34],\n", + " [0.97, 18.36],\n", + " [0.98, 18.36],\n", + " [1.0, 18.38],\n", + " [1.01, 18.38],\n", + " [1.03, 18.4],\n", + " [1.04, 18.4],\n", + " [1.06, 18.42],\n", + " [1.07, 18.42],\n", + " [1.09, 18.44],\n", + " [1.1, 18.44],\n", + " [1.12, 18.46],\n", + " [1.13, 18.46],\n", + " [1.15, 18.48],\n", + " [1.16, 18.48],\n", + " [1.18, 18.5],\n", + " [1.19, 18.5],\n", + " [1.21, 18.52],\n", + " [1.22, 18.52],\n", + " [1.24, 18.54],\n", + " [1.25, 18.54],\n", + " [1.27, 18.56],\n", + " [1.28, 18.56],\n", + " [1.3, 18.58],\n", + " [1.31, 18.58],\n", + " [1.33, 18.6],\n", + " [1.34, 18.6],\n", + " [1.36, 18.62],\n", + " [1.37, 18.62],\n", + " [1.39, 18.64],\n", + " [1.4, 18.64],\n", + " [1.42, 18.66],\n", + " [1.43, 18.66],\n", + " [1.45, 18.68],\n", + " [1.46, 18.68],\n", + " [1.48, 18.7],\n", + " [1.49, 18.7],\n", + " [1.51, 18.72],\n", + " [1.52, 18.72],\n", + " [1.54, 18.74],\n", + " [1.55, 18.74],\n", + " [1.57, 18.76],\n", + " [1.58, 18.76],\n", + " [1.6, 18.78],\n", + " [1.61, 18.78],\n", + " [1.63, 18.8],\n", + " [1.64, 18.8],\n", + " [1.66, 18.82],\n", + " [1.67, 18.82],\n", + " [1.69, 18.84],\n", + " [1.7, 18.84],\n", + " [1.72, 18.86],\n", + " 
[1.73, 18.86],\n", + " [1.75, 18.88],\n", + " [1.76, 18.88],\n", + " [1.78, 18.9],\n", + " [1.79, 18.9],\n", + " [1.81, 18.92],\n", + " [1.82, 18.92],\n", + " [1.84, 18.94],\n", + " [1.85, 18.94],\n", + " [1.87, 18.96],\n", + " [8.8, 18.96],\n", + " [8.8, 15.38],\n", + " [8.81, 15.37],\n", + " [8.81, 13.69],\n", + " [8.32, 13.69],\n", + " [8.31, 13.68],\n", + " [8.31, 13.58],\n", + " [8.32, 13.57],\n", + " [8.81, 13.57],\n", + " [8.81, 10.87],\n", + " [8.82, 10.86],\n", + " [8.92, 10.86],\n", + " [8.93, 10.87],\n", + " [8.93, 11.07],\n", + " [11.61, 11.07],\n", + " [11.62, 11.08],\n", + " [11.62, 11.18],\n", + " [11.61, 11.19],\n", + " [8.93, 11.19],\n", + " [8.93, 15.26],\n", + " [12.71, 15.26],\n", + " [12.71, 11.19],\n", + " [12.47, 11.19],\n", + " [12.46, 11.18],\n", + " [12.46, 11.08],\n", + " [12.47, 11.07],\n", + " [13.04, 11.07],\n", + " [13.05, 11.08],\n", + " [13.05, 11.18],\n", + " [13.04, 11.19],\n", + " [12.83, 11.19],\n", + " [12.83, 15.26],\n", + " [15.81, 15.26],\n", + " [15.81, 11.19],\n", + " [13.9, 11.19],\n", + " [13.89, 11.18],\n", + " [13.89, 11.08],\n", + " [13.9, 11.07],\n", + " [15.99, 11.07],\n", + " [16.0, 11.08],\n", + " [16.0, 11.18],\n", + " [15.99, 11.19],\n", + " [15.93, 11.19],\n", + " [15.93, 15.26],\n", + " [18.9, 15.26],\n", + " [18.9, 11.19],\n", + " [16.85, 11.19],\n", + " [16.84, 11.18],\n", + " [16.84, 11.08],\n", + " [16.85, 11.07],\n", + " [17.01, 11.07],\n", + " [17.01, 11.03],\n", + " [17.02, 11.02],\n", + " [17.12, 11.02],\n", + " [17.13, 11.03],\n", + " [17.13, 11.07],\n", + " [18.9, 11.07],\n", + " [18.9, 9.69],\n", + " [17.13, 9.69],\n", + " [17.13, 10.17],\n", + " [17.12, 10.18],\n", + " [17.02, 10.18],\n", + " [17.01, 10.17],\n", + " [17.01, 9.69],\n", + " [16.85, 9.69],\n", + " [16.84, 9.68],\n", + " [16.84, 9.58],\n", + " [16.85, 9.57],\n", + " [18.9, 9.57],\n", + " [18.9, 5.79],\n", + " [15.73, 5.79],\n", + " [15.73, 9.57],\n", + " [15.99, 9.57],\n", + " [16.0, 9.58],\n", + " [16.0, 9.68],\n", + " [15.99, 9.69],\n", + " [15.38, 9.69],\n", + " [15.37, 9.68],\n", + " [15.37, 9.58],\n", + " [15.38, 9.57],\n", + " [15.61, 9.57],\n", + " [15.61, 7.49],\n", + " [12.85, 7.49],\n", + " [12.84, 7.48],\n", + " [12.84, 7.39],\n", + " [12.85, 7.38],\n", + " [15.61, 7.38],\n", + " [15.61, 5.79],\n", + " [14.17, 5.79],\n", + " [14.16, 5.78],\n", + " [14.16, 5.68],\n", + " [14.17, 5.67],\n", + " [18.9, 5.67],\n", + " [18.9, 0.1],\n", + " [13.24, 0.1],\n", " ],\n", " \"clutter_verts\": [\n", " [\n", - " [\n", - " 7.1517380952380964,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 7.851738095238096,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 7.851738095238096,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 7.1517380952380964,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 7.1517380952380964,\n", - " 17.62209523809524\n", - " ]\n", + " [7.1517380952380964, 17.62209523809524],\n", + " [7.851738095238096, 17.62209523809524],\n", + " [7.851738095238096, 18.322095238095244],\n", + " [7.1517380952380964, 18.322095238095244],\n", + " [7.1517380952380964, 17.62209523809524],\n", " ],\n", " [\n", - " [\n", - " 6.253642857142857,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 6.953642857142857,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 6.953642857142857,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 6.253642857142857,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 6.253642857142857,\n", - " 17.62209523809524\n", - " ]\n", + " [6.253642857142857, 17.62209523809524],\n", + " 
[6.953642857142857, 17.62209523809524],\n", + " [6.953642857142857, 18.322095238095244],\n", + " [6.253642857142857, 18.322095238095244],\n", + " [6.253642857142857, 17.62209523809524],\n", " ],\n", " [\n", - " [\n", - " 6.702690476190478,\n", - " 16.813809523809525\n", - " ],\n", - " [\n", - " 7.402690476190477,\n", - " 16.813809523809525\n", - " ],\n", - " [\n", - " 7.402690476190477,\n", - " 17.513809523809527\n", - " ],\n", - " [\n", - " 6.702690476190478,\n", - " 17.513809523809527\n", - " ],\n", - " [\n", - " 6.702690476190478,\n", - " 16.813809523809525\n", - " ]\n", + " [6.702690476190478, 16.813809523809525],\n", + " [7.402690476190477, 16.813809523809525],\n", + " [7.402690476190477, 17.513809523809527],\n", + " [6.702690476190478, 17.513809523809527],\n", + " [6.702690476190478, 16.813809523809525],\n", " ],\n", " [\n", - " [\n", - " 7.46607142857143,\n", - " 16.768904761904764\n", - " ],\n", - " [\n", - " 8.16607142857143,\n", - " 16.768904761904764\n", - " ],\n", - " [\n", - " 8.16607142857143,\n", - " 17.468904761904767\n", - " ],\n", - " [\n", - " 7.46607142857143,\n", - " 17.468904761904767\n", - " ],\n", - " [\n", - " 7.46607142857143,\n", - " 16.768904761904764\n", - " ]\n", + " [7.46607142857143, 16.768904761904764],\n", + " [8.16607142857143, 16.768904761904764],\n", + " [8.16607142857143, 17.468904761904767],\n", + " [7.46607142857143, 17.468904761904767],\n", + " [7.46607142857143, 16.768904761904764],\n", " ],\n", " [\n", - " [\n", - " 5.35554761904762,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 6.055547619047619,\n", - " 17.62209523809524\n", - " ],\n", - " [\n", - " 6.055547619047619,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 5.35554761904762,\n", - " 18.322095238095244\n", - " ],\n", - " [\n", - " 5.35554761904762,\n", - " 17.62209523809524\n", - " ]\n", + " [5.35554761904762, 17.62209523809524],\n", + " [6.055547619047619, 17.62209523809524],\n", + " [6.055547619047619, 18.322095238095244],\n", + " [5.35554761904762, 18.322095238095244],\n", + " [5.35554761904762, 17.62209523809524],\n", " ],\n", " [\n", - " [\n", - " 5.669880952380954,\n", - " 16.768904761904764\n", - " ],\n", - " [\n", - " 6.369880952380953,\n", - " 16.768904761904764\n", - " ],\n", - " [\n", - " 6.369880952380953,\n", - " 17.468904761904767\n", - " ],\n", - " [\n", - " 5.669880952380954,\n", - " 17.468904761904767\n", - " ],\n", - " [\n", - " 5.669880952380954,\n", - " 16.768904761904764\n", - " ]\n", + " [5.669880952380954, 16.768904761904764],\n", + " [6.369880952380953, 16.768904761904764],\n", + " [6.369880952380953, 17.468904761904767],\n", + " [5.669880952380954, 17.468904761904767],\n", + " [5.669880952380954, 16.768904761904764],\n", " ],\n", " [\n", - " [\n", - " 0.505833333333333,\n", - " 15.107428571428576\n", - " ],\n", - " [\n", - " 1.205833333333333,\n", - " 15.107428571428576\n", - " ],\n", - " [\n", - " 1.205833333333333,\n", - " 15.807428571428575\n", - " ],\n", - " [\n", - " 0.505833333333333,\n", - " 15.807428571428575\n", - " ],\n", - " [\n", - " 0.505833333333333,\n", - " 15.107428571428576\n", - " ]\n", + " [0.505833333333333, 15.107428571428576],\n", + " [1.205833333333333, 15.107428571428576],\n", + " [1.205833333333333, 15.807428571428575],\n", + " [0.505833333333333, 15.807428571428575],\n", + " [0.505833333333333, 15.107428571428576],\n", " ],\n", " [\n", - " [\n", - " 1.4488333333333325,\n", - " 15.062523809523812\n", - " ],\n", - " [\n", - " 2.1488333333333327,\n", - " 15.062523809523812\n", - " ],\n", - " [\n", - " 2.1488333333333327,\n", 
- " 15.762523809523811\n", - " ],\n", - " [\n", - " 1.4488333333333325,\n", - " 15.762523809523811\n", - " ],\n", - " [\n", - " 1.4488333333333325,\n", - " 15.062523809523812\n", - " ]\n", + " [1.4488333333333325, 15.062523809523812],\n", + " [2.1488333333333327, 15.062523809523812],\n", + " [2.1488333333333327, 15.762523809523811],\n", + " [1.4488333333333325, 15.762523809523811],\n", + " [1.4488333333333325, 15.062523809523812],\n", " ],\n", " [\n", - " [\n", - " 1.538642857142856,\n", - " 14.164428571428575\n", - " ],\n", - " [\n", - " 2.2386428571428563,\n", - " 14.164428571428575\n", - " ],\n", - " [\n", - " 2.2386428571428563,\n", - " 14.864428571428574\n", - " ],\n", - " [\n", - " 1.538642857142856,\n", - " 14.864428571428574\n", - " ],\n", - " [\n", - " 1.538642857142856,\n", - " 14.164428571428575\n", - " ]\n", + " [1.538642857142856, 14.164428571428575],\n", + " [2.2386428571428563, 14.164428571428575],\n", + " [2.2386428571428563, 14.864428571428574],\n", + " [1.538642857142856, 14.864428571428574],\n", + " [1.538642857142856, 14.164428571428575],\n", " ],\n", " [\n", - " [\n", - " 2.391833333333333,\n", - " 14.02971428571429\n", - " ],\n", - " [\n", - " 3.091833333333333,\n", - " 14.02971428571429\n", - " ],\n", - " [\n", - " 3.091833333333333,\n", - " 14.729714285714289\n", - " ],\n", - " [\n", - " 2.391833333333333,\n", - " 14.729714285714289\n", - " ],\n", - " [\n", - " 2.391833333333333,\n", - " 14.02971428571429\n", - " ]\n", + " [2.391833333333333, 14.02971428571429],\n", + " [3.091833333333333, 14.02971428571429],\n", + " [3.091833333333333, 14.729714285714289],\n", + " [2.391833333333333, 14.729714285714289],\n", + " [2.391833333333333, 14.02971428571429],\n", " ],\n", " [\n", - " [\n", - " 2.4367380952380944,\n", - " 15.017619047619052\n", - " ],\n", - " [\n", - " 3.1367380952380945,\n", - " 15.017619047619052\n", - " ],\n", - " [\n", - " 3.1367380952380945,\n", - " 15.717619047619051\n", - " ],\n", - " [\n", - " 2.4367380952380944,\n", - " 15.717619047619051\n", - " ],\n", - " [\n", - " 2.4367380952380944,\n", - " 15.017619047619052\n", - " ]\n", + " [2.4367380952380944, 15.017619047619052],\n", + " [3.1367380952380945, 15.017619047619052],\n", + " [3.1367380952380945, 15.717619047619051],\n", + " [2.4367380952380944, 15.717619047619051],\n", + " [2.4367380952380944, 15.017619047619052],\n", " ],\n", " [\n", - " [\n", - " 9.307166666666667,\n", - " 11.425238095238099\n", - " ],\n", - " [\n", - " 10.007166666666667,\n", - " 11.425238095238099\n", - " ],\n", - " [\n", - " 10.007166666666667,\n", - " 12.125238095238098\n", - " ],\n", - " [\n", - " 9.307166666666667,\n", - " 12.125238095238098\n", - " ],\n", - " [\n", - " 9.307166666666667,\n", - " 11.425238095238099\n", - " ]\n", + " [9.307166666666667, 11.425238095238099],\n", + " [10.007166666666667, 11.425238095238099],\n", + " [10.007166666666667, 12.125238095238098],\n", + " [9.307166666666667, 12.125238095238098],\n", + " [9.307166666666667, 11.425238095238099],\n", " ],\n", " [\n", - " [\n", - " 9.307166666666667,\n", - " 12.233523809523811\n", - " ],\n", - " [\n", - " 10.007166666666667,\n", - " 12.233523809523811\n", - " ],\n", - " [\n", - " 10.007166666666667,\n", - " 12.93352380952381\n", - " ],\n", - " [\n", - " 9.307166666666667,\n", - " 12.93352380952381\n", - " ],\n", - " [\n", - " 9.307166666666667,\n", - " 12.233523809523811\n", - " ]\n", + " [9.307166666666667, 12.233523809523811],\n", + " [10.007166666666667, 12.233523809523811],\n", + " [10.007166666666667, 12.93352380952381],\n", + " 
[9.307166666666667, 12.93352380952381],\n", + " [9.307166666666667, 12.233523809523811],\n", " ],\n", " [\n", - " [\n", - " 9.35207142857143,\n", - " 12.907095238095243\n", - " ],\n", - " [\n", - " 10.052071428571429,\n", - " 12.907095238095243\n", - " ],\n", - " [\n", - " 10.052071428571429,\n", - " 13.607095238095242\n", - " ],\n", - " [\n", - " 9.35207142857143,\n", - " 13.607095238095242\n", - " ],\n", - " [\n", - " 9.35207142857143,\n", - " 12.907095238095243\n", - " ]\n", + " [9.35207142857143, 12.907095238095243],\n", + " [10.052071428571429, 12.907095238095243],\n", + " [10.052071428571429, 13.607095238095242],\n", + " [9.35207142857143, 13.607095238095242],\n", + " [9.35207142857143, 12.907095238095243],\n", " ],\n", " [\n", - " [\n", - " 16.536833333333334,\n", - " 4.779333333333335\n", - " ],\n", - " [\n", - " 17.236833333333337,\n", - " 4.779333333333335\n", - " ],\n", - " [\n", - " 17.236833333333337,\n", - " 5.479333333333335\n", - " ],\n", - " [\n", - " 16.536833333333334,\n", - " 5.479333333333335\n", - " ],\n", - " [\n", - " 16.536833333333334,\n", - " 4.779333333333335\n", - " ]\n", + " [16.536833333333334, 4.779333333333335],\n", + " [17.236833333333337, 4.779333333333335],\n", + " [17.236833333333337, 5.479333333333335],\n", + " [16.536833333333334, 5.479333333333335],\n", + " [16.536833333333334, 4.779333333333335],\n", " ],\n", " [\n", - " [\n", - " 17.479833333333332,\n", - " 4.779333333333335\n", - " ],\n", - " [\n", - " 18.179833333333335,\n", - " 4.779333333333335\n", - " ],\n", - " [\n", - " 18.179833333333335,\n", - " 5.479333333333335\n", - " ],\n", - " [\n", - " 17.479833333333332,\n", - " 5.479333333333335\n", - " ],\n", - " [\n", - " 17.479833333333332,\n", - " 4.779333333333335\n", - " ]\n", + " [17.479833333333332, 4.779333333333335],\n", + " [18.179833333333335, 4.779333333333335],\n", + " [18.179833333333335, 5.479333333333335],\n", + " [17.479833333333332, 5.479333333333335],\n", + " [17.479833333333332, 4.779333333333335],\n", " ],\n", " [\n", - " [\n", - " 17.479833333333332,\n", - " 3.8812380952380976\n", - " ],\n", - " [\n", - " 18.179833333333335,\n", - " 3.8812380952380976\n", - " ],\n", - " [\n", - " 18.179833333333335,\n", - " 4.581238095238097\n", - " ],\n", - " [\n", - " 17.479833333333332,\n", - " 4.581238095238097\n", - " ],\n", - " [\n", - " 17.479833333333332,\n", - " 3.8812380952380976\n", - " ]\n", + " [17.479833333333332, 3.8812380952380976],\n", + " [18.179833333333335, 3.8812380952380976],\n", + " [18.179833333333335, 4.581238095238097],\n", + " [17.479833333333332, 4.581238095238097],\n", + " [17.479833333333332, 3.8812380952380976],\n", " ],\n", " [\n", - " [\n", - " 16.671547619047622,\n", - " 3.8363333333333354\n", - " ],\n", - " [\n", - " 17.371547619047625,\n", - " 3.8363333333333354\n", - " ],\n", - " [\n", - " 17.371547619047625,\n", - " 4.536333333333335\n", - " ],\n", - " [\n", - " 16.671547619047622,\n", - " 4.536333333333335\n", - " ],\n", - " [\n", - " 16.671547619047622,\n", - " 3.8363333333333354\n", - " ]\n", + " [16.671547619047622, 3.8363333333333354],\n", + " [17.371547619047625, 3.8363333333333354],\n", + " [17.371547619047625, 4.536333333333335],\n", + " [16.671547619047622, 4.536333333333335],\n", + " [16.671547619047622, 3.8363333333333354],\n", " ],\n", " [\n", - " [\n", - " 17.5247380952381,\n", - " 2.983142857142859\n", - " ],\n", - " [\n", - " 18.224738095238102,\n", - " 2.983142857142859\n", - " ],\n", - " [\n", - " 18.224738095238102,\n", - " 3.683142857142859\n", - " ],\n", - " [\n", - " 
17.5247380952381,\n", - " 3.683142857142859\n", - " ],\n", - " [\n", - " 17.5247380952381,\n", - " 2.983142857142859\n", - " ]\n", + " [17.5247380952381, 2.983142857142859],\n", + " [18.224738095238102, 2.983142857142859],\n", + " [18.224738095238102, 3.683142857142859],\n", + " [17.5247380952381, 3.683142857142859],\n", + " [17.5247380952381, 2.983142857142859],\n", " ],\n", " [\n", - " [\n", - " 17.5247380952381,\n", - " 2.0850476190476206\n", - " ],\n", - " [\n", - " 18.224738095238102,\n", - " 2.0850476190476206\n", - " ],\n", - " [\n", - " 18.224738095238102,\n", - " 2.7850476190476208\n", - " ],\n", - " [\n", - " 17.5247380952381,\n", - " 2.7850476190476208\n", - " ],\n", - " [\n", - " 17.5247380952381,\n", - " 2.0850476190476206\n", - " ]\n", - " ]\n", - " ]\n", + " [17.5247380952381, 2.0850476190476206],\n", + " [18.224738095238102, 2.0850476190476206],\n", + " [18.224738095238102, 2.7850476190476208],\n", + " [17.5247380952381, 2.7850476190476208],\n", + " [17.5247380952381, 2.0850476190476206],\n", + " ],\n", + " ],\n", "}" ] }, @@ -2662,9 +760,9 @@ "# print(data[\"paths\"]);\n", "# print(data[\"verts\"]);\n", "points = jnp.array(env_data[\"verts\"])\n", - "plt.plot(points[:,0],points[:,1])\n", + "plt.plot(points[:, 0], points[:, 1])\n", "paths = jnp.array(env_data[\"paths\"][0])\n", - "plt.plot(paths[:,0], paths[:,1])" + "plt.plot(paths[:, 0], paths[:, 1])" ] }, { @@ -2677,10 +775,14 @@ "pieces = []\n", "for i in range(len(points) - 1):\n", " point_1 = points[i]\n", - " point_2 = points[i+1]\n", - " p = b.t3d.transform_from_pos(jnp.concatenate([(point_1 + point_2) /2 , jnp.array([0.0])]))\n", + " point_2 = points[i + 1]\n", + " p = b.t3d.transform_from_pos(\n", + " jnp.concatenate([(point_1 + point_2) / 2, jnp.array([0.0])])\n", + " )\n", "\n", - " dimensions = np.array(jnp.concatenate([jnp.abs(point_1 - point_2) , jnp.array([1.0])]))\n", + " dimensions = np.array(\n", + " jnp.concatenate([jnp.abs(point_1 - point_2), jnp.array([1.0])])\n", + " )\n", " piece = trimesh.creation.box(dimensions, p)\n", "\n", " # print(\"==============================\")\n", @@ -2702,7 +804,10 @@ "metadata": {}, "outputs": [], "source": [ - "pose_from_position = lambda p: b.t3d.transform_from_rot_and_pos(b.t3d.rotation_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), -jnp.pi/2), jnp.concatenate([p, jnp.array([0.0])]))\n", + "pose_from_position = lambda p: b.t3d.transform_from_rot_and_pos(\n", + " b.t3d.rotation_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), -jnp.pi / 2),\n", + " jnp.concatenate([p, jnp.array([0.0])]),\n", + ")\n", "pose_sequence = jax.vmap(pose_from_position)(paths)" ] }, @@ -2717,7 +822,7 @@ "b.RENDERER.add_mesh(map_mesh, center_mesh=False)\n", "b.clear()\n", "b.show_trimesh(\"1\", map_mesh)\n", - "for (i,p) in enumerate(pose_sequence):\n", + "for i, p in enumerate(pose_sequence):\n", " b.show_pose(f\"2_{i}\", p)" ] }, @@ -2728,7 +833,9 @@ "metadata": {}, "outputs": [], "source": [ - "img = b.RENDERER.render(b.t3d.inverse_pose(pose_sequence[10])[None,...], jnp.array([0]))" + "img = b.RENDERER.render(\n", + " b.t3d.inverse_pose(pose_sequence[10])[None, ...], jnp.array([0])\n", + ")" ] }, { @@ -2738,8 +845,10 @@ "metadata": {}, "outputs": [], "source": [ - "cloud = b.t3d.unproject_depth_jit(img[:,:,2], intrinsics)\n", - "b.show_cloud(\"reproject\", b.t3d.apply_transform(cloud, pose_sequence[10]).reshape(-1,3))" + "cloud = b.t3d.unproject_depth_jit(img[:, :, 2], intrinsics)\n", + "b.show_cloud(\n", + " \"reproject\", b.t3d.apply_transform(cloud, pose_sequence[10]).reshape(-1, 3)\n", + ")" ] }, 
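The `pieces` loop above turns each consecutive pair of `env_data["verts"]` into a thin axis-aligned box and concatenates the boxes into the single `map_mesh` used for rendering. A minimal standalone sketch of that construction with plain `numpy`/`trimesh`; the `thickness` floor is an added assumption (the notebook uses the raw coordinate differences):

```python
import numpy as np
import trimesh

def walls_to_mesh(verts, height=1.0, thickness=0.02):
    """Extrude consecutive 2D vertices into thin boxes and merge them."""
    verts = np.asarray(verts, dtype=float)
    pieces = []
    for p1, p2 in zip(verts[:-1], verts[1:]):
        # Extents: the segment's x/y span, floored (assumption) so that
        # axis-aligned walls do not collapse to zero width; unit z height.
        extents = np.concatenate([np.maximum(np.abs(p2 - p1), thickness), [height]])
        transform = np.eye(4)
        transform[:3, 3] = np.concatenate([(p1 + p2) / 2.0, [0.0]])
        pieces.append(trimesh.creation.box(extents, transform))
    return trimesh.util.concatenate(pieces)

map_mesh = walls_to_mesh([[0.0, 0.0], [2.0, 0.0], [2.0, 1.5]])
```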
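`pose_from_position` then lifts each 2D path point to a camera pose by composing a fixed -pi/2 tilt about the x axis with a translation, and `jax.vmap` maps it over the whole path. A hedged re-derivation, with the axis-angle helper spelled out rather than taken from `b.t3d`:

```python
import jax
import jax.numpy as jnp

def rot_x(angle):
    """Rotation matrix about the x axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0, c, -s],
                      [0.0, s, c]])

def pose_from_2d_point(p):
    """4x4 pose: tilt -pi/2 about x, translate to (p[0], p[1], 0)."""
    pose = jnp.eye(4).at[:3, :3].set(rot_x(-jnp.pi / 2))
    return pose.at[:3, 3].set(jnp.concatenate([p, jnp.array([0.0])]))

pose_sequence = jax.vmap(pose_from_2d_point)(jnp.zeros((4, 2)))  # shape (4, 4, 4)
```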
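The last cell lifts the rendered depth channel back into a world-frame point cloud via `b.t3d.unproject_depth_jit` and `b.t3d.apply_transform`. A sketch of the underlying pinhole math, assuming the usual `(u - cx) / fx` convention and a row-major `(height, width)` depth image; this is the standard derivation, not necessarily bayes3d's exact internals:

```python
import jax.numpy as jnp

def unproject_depth(depth, fx, fy, cx, cy):
    """Lift an (H, W) depth image to an (H*W, 3) camera-frame cloud."""
    h, w = depth.shape
    u, v = jnp.meshgrid(jnp.arange(w), jnp.arange(h))  # pixel grids, shape (H, W)
    x = (u - cx) / fx * depth
    y = (v - cy) / fy * depth
    return jnp.stack([x, y, depth], axis=-1).reshape(-1, 3)

def apply_pose(pose, cloud):
    """Apply a 4x4 rigid transform to an (N, 3) cloud."""
    return cloud @ pose[:3, :3].T + pose[:3, 3]

# With the 1x50 intrinsics defined earlier (fx=10, fy=1, cx=25, cy=0):
cloud = unproject_depth(jnp.ones((1, 50)), 10.0, 1.0, 25.0, 0.0)
cloud_world = apply_pose(jnp.eye(4), cloud)
```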
{ @@ -2751,26 +860,33 @@ "source": [ "@genjax.gen\n", "def slam_single_frame():\n", - " agent_pose = b.genjax.uniform_pose(\n", - " jnp.ones(3) * -30.0,\n", - " jnp.ones(3) * 30.0,\n", - " ) @ \"agent_pose\"\n", + " agent_pose = (\n", + " b.genjax.uniform_pose(\n", + " jnp.ones(3) * -30.0,\n", + " jnp.ones(3) * 30.0,\n", + " )\n", + " @ \"agent_pose\"\n", + " )\n", "\n", - " rendered = b.RENDERER.render(\n", - " jnp.linalg.inv(agent_pose)[None,...], jnp.array([0])\n", - " )[...,:3]\n", + " rendered = b.RENDERER.render(jnp.linalg.inv(agent_pose)[None, ...], jnp.array([0]))[\n", + " ..., :3\n", + " ]\n", " # rendered = jnp.ones((intrinsics.height, intrinsics.width, 3)) * agent_pose[0,3]\n", - " \n", + "\n", " variance = genjax.distributions.tfp_uniform(0.0001, 0.1) @ \"variance\"\n", - " outlier_prob = genjax.distributions.tfp_uniform(0.0001, 0.1) @ \"outlier_prob\"\n", + " outlier_prob = genjax.distributions.tfp_uniform(0.0001, 0.1) @ \"outlier_prob\"\n", " image = b.genjax.image_likelihood(rendered, variance, outlier_prob, 10.0) @ \"image\"\n", " return agent_pose, rendered, image\n", "\n", + "\n", "def viz_trace(trace):\n", " b.clear()\n", " b.show_trimesh(\"1\", map_mesh)\n", " b.show_pose(f\"pose\", trace[\"agent_pose\"])\n", - " b.show_cloud(\"reproject\", b.t3d.apply_transform(trace[\"image\"], trace[\"agent_pose\"]).reshape(-1,3))" + " b.show_cloud(\n", + " \"reproject\",\n", + " b.t3d.apply_transform(trace[\"image\"], trace[\"agent_pose\"]).reshape(-1, 3),\n", + " )" ] }, { @@ -2783,11 +899,11 @@ "def enumerate_pose(trace, key, pose):\n", " return trace.update(\n", " key,\n", - " genjax.choice_map({\n", - " \"agent_pose\": pose\n", - " }),\n", + " genjax.choice_map({\"agent_pose\": pose}),\n", " tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),\n", " )[1][2].get_score()\n", + "\n", + "\n", "enumerate_pose_vmap = jax.vmap(enumerate_pose, in_axes=(None, None, 0))" ] }, @@ -2809,13 +925,17 @@ "outputs": [], "source": [ "key = jax.random.PRNGKey(10)\n", - "(key, (_, trace))= slam_single_frame.importance(key,\n", - " genjax.choice_map({\n", - " \"agent_pose\": pose_sequence[2],\n", - " \"variance\": 0.01,\n", - " \"outlier_prob\": 0.01,\n", - " \n", - " }), ())\n", + "(key, (_, trace)) = slam_single_frame.importance(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"agent_pose\": pose_sequence[2],\n", + " \"variance\": 0.01,\n", + " \"outlier_prob\": 0.01,\n", + " }\n", + " ),\n", + " (),\n", + ")\n", "print(trace.get_score())\n", "viz_trace(trace)" ] @@ -2828,7 +948,9 @@ "outputs": [], "source": [ "%%time\n", - "scores = enumerator.enumerate_choices_get_scores(trace, key, jnp.tile(pose_sequence[2][None,...], (50000,1,1))) \n", + "scores = enumerator.enumerate_choices_get_scores(\n", + " trace, key, jnp.tile(pose_sequence[2][None, ...], (50000, 1, 1))\n", + ")\n", "print(scores)" ] }, @@ -2893,7 +1015,7 @@ "metadata": {}, "outputs": [], "source": [ - "grid.reshape(100, 100, -1)[-1,0]" + "grid.reshape(100, 100, -1)[-1, 0]" ] }, { diff --git a/scripts/experiments/slam/slam_with_room_obj.ipynb b/scripts/experiments/slam/slam_with_room_obj.ipynb index 5f1e5ce2..97a41ce9 100644 --- a/scripts/experiments/slam/slam_with_room_obj.ipynb +++ b/scripts/experiments/slam/slam_with_room_obj.ipynb @@ -39,16 +39,12 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=75.0, fy=75.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.001, far=16.0\n", + " height=100, width=100, fx=75.0, fy=75.0, cx=50.0, cy=50.0, 
near=0.001, far=16.0\n", ")\n", "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", - "jax_renderer = JaxRenderer(intrinsics)\n" + "\n", + "jax_renderer = JaxRenderer(intrinsics)" ] }, { @@ -68,8 +64,9 @@ } ], "source": [ - "\n", "import trimesh\n", + "\n", + "\n", "def as_mesh(scene_or_mesh):\n", " \"\"\"\n", " Convert a possible scene to a mesh.\n", @@ -82,18 +79,23 @@ " else:\n", " # we lose texture information here\n", " mesh = trimesh.util.concatenate(\n", - " tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)\n", - " for g in scene_or_mesh.geometry.values()))\n", + " tuple(\n", + " trimesh.Trimesh(vertices=g.vertices, faces=g.faces)\n", + " for g in scene_or_mesh.geometry.values()\n", + " )\n", + " )\n", " else:\n", - " assert(isinstance(mesh, trimesh.Trimesh))\n", + " assert isinstance(mesh, trimesh.Trimesh)\n", " mesh = scene_or_mesh\n", " return mesh\n", - "mesh =as_mesh(trimesh.load('InteriorTest.obj'))\n", - "mesh.vertices = mesh.vertices * jnp.array([1.0, -1.0, 1.0]) + jnp.array([0.0, 1.0, 0.0])\n", + "\n", + "\n", + "mesh = as_mesh(trimesh.load(\"InteriorTest.obj\"))\n", + "mesh.vertices = mesh.vertices * jnp.array([1.0, -1.0, 1.0]) + jnp.array([0.0, 1.0, 0.0])\n", "vertices = mesh.vertices\n", "faces = mesh.faces\n", "print(b.utils.aabb(mesh.vertices))\n", - "b.show_trimesh(\"1\",mesh)" + "b.show_trimesh(\"1\", mesh)" ] }, { @@ -110,10 +112,22 @@ " elif move == \"left\":\n", " return pose @ b.transform_from_pos(jnp.array([-0.1, 0.0, 0.0]))\n", " elif move == \"rotate_left\":\n", - " return pose @ b.transform_from_axis_angle(jnp.array([-0.0, 1.0, 0.0]), jnp.deg2rad(10.0) * -1.0)\n", + " return pose @ b.transform_from_axis_angle(\n", + " jnp.array([-0.0, 1.0, 0.0]), jnp.deg2rad(10.0) * -1.0\n", + " )\n", " elif move == \"rotate_right\":\n", - " return pose @ b.transform_from_axis_angle(jnp.array([-0.0, 1.0, 0.0]), jnp.deg2rad(10.0) * 1.0)\n", - "moves = [*[\"ahead\" for _ in range(20)], *[\"rotate_left\" for _ in range(8)], *[\"ahead\" for _ in range(10)], *[\"rotate_right\" for _ in range(4)],*[\"ahead\" for _ in range(10)]]\n", + " return pose @ b.transform_from_axis_angle(\n", + " jnp.array([-0.0, 1.0, 0.0]), jnp.deg2rad(10.0) * 1.0\n", + " )\n", + "\n", + "\n", + "moves = [\n", + " *[\"ahead\" for _ in range(20)],\n", + " *[\"rotate_left\" for _ in range(8)],\n", + " *[\"ahead\" for _ in range(10)],\n", + " *[\"rotate_right\" for _ in range(4)],\n", + " *[\"ahead\" for _ in range(10)],\n", + "]\n", "camera_poses = [jnp.eye(4)]\n", "for move in moves:\n", " camera_poses.append(apply_move(camera_poses[-1], move))\n", @@ -127,9 +141,9 @@ "outputs": [], "source": [ "gt_images = [\n", - " jax_renderer.render(vertices, faces, b.inverse_pose(p), intrinsics)[0][0,...]\n", + " jax_renderer.render(vertices, faces, b.inverse_pose(p), intrinsics)[0][0, ...]\n", " for p in camera_poses\n", - "]\n" + "]" ] }, { @@ -138,7 +152,10 @@ "metadata": {}, "outputs": [], "source": [ - "b.make_gif_from_pil_images([b.get_depth_image(img, min_val=0.0, remove_max=False) for img in gt_images], \"gt.gif\")" + "b.make_gif_from_pil_images(\n", + " [b.get_depth_image(img, min_val=0.0, remove_max=False) for img in gt_images],\n", + " \"gt.gif\",\n", + ")" ] }, { @@ -149,10 +166,21 @@ "source": [ "def loss(trans, q, gt_img):\n", " camera_pose = b.translation_and_quaternion_to_pose_matrix(trans, q)\n", - " img = jax_renderer.render(vertices, faces, b.inverse_pose(camera_pose), intrinsics)[0][0,...]\n", + " img = jax_renderer.render(vertices, faces, 
b.inverse_pose(camera_pose), intrinsics)[\n", + " 0\n", + " ][0, ...]\n", " return (jnp.abs(img - gt_img)).mean()\n", "\n", - "value_and_grad_jit = jax.jit(jax.value_and_grad(loss, argnums=(0,1,)))" + "\n", + "value_and_grad_jit = jax.jit(\n", + " jax.value_and_grad(\n", + " loss,\n", + " argnums=(\n", + " 0,\n", + " 1,\n", + " ),\n", + " )\n", + ")" ] }, { @@ -163,8 +191,8 @@ "source": [ "b.clear()\n", "b.show_pose(\"actual\", camera_poses[1])\n", - "tr,q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)" + "tr, q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])\n", + "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)" ] }, { @@ -188,17 +216,17 @@ } ], "source": [ - "print(\"start \" , value_and_grad_jit(tr, q, gt_images[1]))\n", + "print(\"start \", value_and_grad_jit(tr, q, gt_images[1]))\n", "poses = []\n", "pbar = tqdm(range(200))\n", "timestep = 1\n", - "for _ in pbar:\n", + "for _ in pbar:\n", " loss, (g1, g2) = value_and_grad_jit(tr, q, gt_images[timestep])\n", " tr -= g1 * 0.01\n", " q -= g2 * 0.01\n", " pbar.set_description(f\"{loss}\")\n", " # poses.append(b.translation_and_quaternion_to_pose_matrix(tr,q))\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)" + "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)" ] }, { @@ -277,20 +305,22 @@ "source": [ "b.clear()\n", "b.show_pose(\"actual\", camera_poses[1])\n", - "tr,q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])\n", - "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)\n", + "tr, q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])\n", + "b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1)\n", "inferred_poses = []\n", "pbar2 = tqdm(range(len(gt_images)))\n", - "for timestep in pbar2:\n", + "for timestep in pbar2:\n", " pbar = tqdm(range(50))\n", - " for _ in pbar:\n", + " for _ in pbar:\n", " loss, (g1, g2) = value_and_grad_jit(tr, q, gt_images[timestep])\n", " tr -= g1 * 0.01\n", " q -= g2 * 0.01\n", " pbar.set_description(f\"{loss}\")\n", " b.show_pose(\"actual\", camera_poses[timestep])\n", - " b.show_pose(\"inferred\", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)\n", - " inferred_poses.append(b.translation_and_quaternion_to_pose_matrix(tr,q))\n", + " b.show_pose(\n", + " \"inferred\", b.translation_and_quaternion_to_pose_matrix(tr, q), size=0.1\n", + " )\n", + " inferred_poses.append(b.translation_and_quaternion_to_pose_matrix(tr, q))\n", "inferred_poses = jnp.stack(inferred_poses)" ] }, @@ -321,28 +351,56 @@ "source": [ "buffs = []\n", "for i in tqdm(range(len(inferred_poses))):\n", - "# for i in tqdm(range(1)):\n", - " fig = plt.figure(figsize=(8,12))\n", + " # for i in tqdm(range(1)):\n", + " fig = plt.figure(figsize=(8, 12))\n", " fig.suptitle(\"Localization using Differentiable Rendering\", fontsize=21)\n", " # fig = plt.figure()\n", " gs = gridspec.GridSpec(3, 2)\n", "\n", - " ax = fig.add_subplot(gs[0,0])\n", - " ax.imshow(np.array(b.get_depth_image(gt_images[i], min_val=0.0, max_val=5.0, remove_max=False)))\n", + " ax = fig.add_subplot(gs[0, 0])\n", + " ax.imshow(\n", + " np.array(\n", + " b.get_depth_image(gt_images[i], min_val=0.0, max_val=5.0, remove_max=False)\n", + " )\n", + " )\n", " ax.set_title(\"Observed Depth Image\")\n", " ax.axis(\"off\")\n", - " ax = 
fig.add_subplot(gs[0,1])\n", + " ax = fig.add_subplot(gs[0, 1])\n", " camera_pose_inferred = inferred_poses[i]\n", - " img = jax_renderer.render(vertices, faces, b.inverse_pose(camera_pose_inferred), intrinsics)[0][0,...]\n", - " ax.imshow(np.array(b.get_depth_image(gt_images[i], min_val=0.0, max_val=5.0, remove_max=False)))\n", + " img = jax_renderer.render(\n", + " vertices, faces, b.inverse_pose(camera_pose_inferred), intrinsics\n", + " )[0][0, ...]\n", + " ax.imshow(\n", + " np.array(\n", + " b.get_depth_image(gt_images[i], min_val=0.0, max_val=5.0, remove_max=False)\n", + " )\n", + " )\n", " ax.set_title(\"Depth Image at Inferred Pose\")\n", " ax.axis(\"off\")\n", "\n", - " ax = fig.add_subplot(gs[1:,:])\n", + " ax = fig.add_subplot(gs[1:, :])\n", " length = 0.2\n", - " ax.quiver(camera_poses[:,0,3], camera_poses[:,2,3], camera_poses[:,0,2]*length, camera_poses[:,2,2]*length, angles=\"xy\", color=\"black\", label=\"Ground Truth Poses\", alpha=0.1)\n", + " ax.quiver(\n", + " camera_poses[:, 0, 3],\n", + " camera_poses[:, 2, 3],\n", + " camera_poses[:, 0, 2] * length,\n", + " camera_poses[:, 2, 2] * length,\n", + " angles=\"xy\",\n", + " color=\"black\",\n", + " label=\"Ground Truth Poses\",\n", + " alpha=0.1,\n", + " )\n", " length = 0.05\n", - " ax.quiver(inferred_poses[i,0,3], inferred_poses[i,2,3], inferred_poses[i,0,2]*length, inferred_poses[i,2,2]*length, angles=\"xy\", color=\"red\", label=\"Inferred Pose\", alpha=0.4)\n", + " ax.quiver(\n", + " inferred_poses[i, 0, 3],\n", + " inferred_poses[i, 2, 3],\n", + " inferred_poses[i, 0, 2] * length,\n", + " inferred_poses[i, 2, 2] * length,\n", + " angles=\"xy\",\n", + " color=\"red\",\n", + " label=\"Inferred Pose\",\n", + " alpha=0.4,\n", + " )\n", " ax.set_xlim(-2.0, 1.5)\n", " ax.set_ylim(0.0, 3.5)\n", " ax.set_aspect(\"equal\")\n", @@ -427,10 +485,21 @@ "source": [ "b.make_gif_from_pil_images(buffs, \"inferred.gif\")\n", "import subprocess\n", + "\n", "fps = 3.0\n", "for i in range(len(buffs)):\n", " buffs[i].convert(\"RGB\").save(\"%07d.png\" % i)\n", - "subprocess.call([\"ffmpeg\",\"-y\",\"-r\",str(fps),\"-i\", \"%07d.png\",\"localization_with_gradients.mp4\"])" + "subprocess.call(\n", + " [\n", + " \"ffmpeg\",\n", + " \"-y\",\n", + " \"-r\",\n", + " str(fps),\n", + " \"-i\",\n", + " \"%07d.png\",\n", + " \"localization_with_gradients.mp4\",\n", + " ]\n", + ")" ] }, { diff --git a/scripts/experiments/tabletop/analysis.ipynb b/scripts/experiments/tabletop/analysis.ipynb index 7c933c22..be5e4d39 100644 --- a/scripts/experiments/tabletop/analysis.ipynb +++ b/scripts/experiments/tabletop/analysis.ipynb @@ -36,21 +36,22 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=500.0, fy=500.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=20.0\n", + " height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", 
"\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -109,12 +110,12 @@ " # print(pred_set[scene_id][\"variance\"])\n", " # print(pred_set[scene_id][\"outlier_prob\"])\n", " if set(pred_ids) == set(gt_ids):\n", - " correct +=1\n", + " correct += 1\n", " else:\n", " wrong_prediction.append(scene_id)\n", " print(gt_ids, pred_ids)\n", " continue\n", - " print(correct) " + " print(correct)" ] }, { @@ -144,20 +145,28 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "importance_jit = jax.jit(b.model.importance)\n", "\n", - "contact_enumerators = [b.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"]) for i in range(5)]\n", + "contact_enumerators = [\n", + " b.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"])\n", + " for i in range(5)\n", + "]\n", "add_object_jit = jax.jit(b.add_object)\n", "\n", - "def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):\n", + "\n", + "def c2f_contact_update(\n", + " trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID\n", + "):\n", " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_{number}\"]\n", - " scores = contact_enumerators[number].enumerate_choices_get_scores(trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID)\n", - " i,j,k = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " scores = contact_enumerators[number].enumerate_choices_get_scores(\n", + " trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID\n", + " )\n", + " i, j, k = jnp.unravel_index(scores.argmax(), scores.shape)\n", " return contact_enumerators[number].update_choices(\n", - " trace_, key,\n", - " contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", + " trace_, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", " )\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=(\"number\",))" ] }, @@ -173,18 +182,20 @@ "OUTLIER_GRID = jnp.array([0.00001, 0.0001, 0.001])\n", "\n", "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.05, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi, (11, 11, 11)),\n", + " (0.1, jnp.pi, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.05, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]\n", - "key = jax.random.PRNGKey(500)\n" + "key = jax.random.PRNGKey(500)" ] }, { @@ -212,7 +223,10 @@ " V_GRID = VARIANCE_GRID\n", " O_GRID = OUTLIER_GRID\n", "else:\n", - " V_GRID, O_GRID = jnp.array([VARIANCE_GRID[V_VARIANT]]), jnp.array([OUTLIER_GRID[O_VARIANT]])\n", + " V_GRID, O_GRID = (\n", + " jnp.array([VARIANCE_GRID[V_VARIANT]]),\n", + " jnp.array([OUTLIER_GRID[O_VARIANT]]),\n", + " 
)\n", "\n", "print(V_GRID, O_GRID)\n", "\n", @@ -220,7 +234,9 @@ "print(b.genjax.get_indices(gt_trace))\n", "b.genjax.viz_trace_meshcat(gt_trace)\n", "choices = gt_trace.get_choices()\n", - "key, (_,trace) = importance_jit(key, choices, (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:-1], 100.0))\n", + "key, (_, trace) = importance_jit(\n", + " key, choices, (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:-1], 100.0)\n", + ")\n", "print(trace.get_score())" ] }, @@ -239,28 +255,38 @@ " )\n", ")\n", "\n", - "weight, gt_trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"parent_2\": 0,\n", - " \"parent_3\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"face_parent_1\": 2,\n", - " \"face_parent_2\": 2,\n", - " \"face_parent_3\": 2,\n", - " \"face_child_1\": 3,\n", - " \"face_child_2\": 3,\n", - " \"face_child_3\": 3,\n", - " \"variance\": 0.0001,\n", - " \"outlier_prob\": 0.0001,\n", - "}), (\n", - " jnp.arange(4),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, OUTLIER_VOLUME, 1.0)\n", + "weight, gt_trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"parent_2\": 0,\n", + " \"parent_3\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"face_parent_1\": 2,\n", + " \"face_parent_2\": 2,\n", + " \"face_parent_3\": 2,\n", + " \"face_child_1\": 3,\n", + " \"face_child_2\": 3,\n", + " \"face_child_3\": 3,\n", + " \"variance\": 0.0001,\n", + " \"outlier_prob\": 0.0001,\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(4),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [jnp.array([-0.2, -0.2, -2 * jnp.pi]), jnp.array([0.2, 0.2, 2 * jnp.pi])]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " OUTLIER_VOLUME,\n", + " 1.0,\n", + " ),\n", ")\n", "print(gt_trace.get_score())" ] @@ -272,7 +298,9 @@ "metadata": {}, "outputs": [], "source": [ - "_,trace = importance_jit(key, trace.get_choices(), (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:]))" + "_, trace = importance_jit(\n", + " key, trace.get_choices(), (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:])\n", + ")" ] }, { @@ -315,23 +343,27 @@ "all_all_paths = []\n", "for _ in range(3):\n", " all_paths = []\n", - " for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)):\n", + " for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)):\n", " path = []\n", - " trace_ = add_object_jit(trace, key, obj_id, 0, 2,3)\n", + " trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)\n", " number = b.get_contact_params(trace_).shape[0] - 1\n", " path.append(trace_)\n", " for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace_ = c2f_contact_update_jit(trace_, key, number,\n", - " contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID)\n", + " trace_ = c2f_contact_update_jit(\n", + " trace_,\n", + " key,\n", + " number,\n", + " contact_param_gridding_schedule[c2f_iter],\n", + " V_GRID,\n", + " O_GRID,\n", + " )\n", " path.append(trace_)\n", " # for c2f_iter in range(len(contact_param_gridding_schedule)):\n", " # trace_ = c2f_contact_update_jit(trace_, key, number,\n", " # contact_param_gridding_schedule[c2f_iter], 
VARIANCE_GRID, OUTLIER_GRID)\n", - " all_paths.append(\n", - " path\n", - " )\n", + " all_paths.append(path)\n", " all_all_paths.append(all_paths)\n", - " \n", + "\n", " scores = jnp.array([t[-1].get_score() for t in all_paths])\n", " print(scores)\n", " normalized_scores = b.utils.normalize_log_scores(scores)\n", @@ -356,28 +388,30 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "all_all_paths = []\n", "for _ in range(3):\n", " all_paths = []\n", - " for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)):\n", + " for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)):\n", " path = []\n", - " trace_ = add_object_jit(trace, key, obj_id, 0, 2,3)\n", + " trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)\n", " number = b.genjax.get_contact_params(trace_).shape[0] - 1\n", " path.append(trace_)\n", " for c2f_iter in range(len(contact_param_gridding_schedule)):\n", - " trace_ = c2f_contact_update_jit(trace_, key, number,\n", - " contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID)\n", + " trace_ = c2f_contact_update_jit(\n", + " trace_,\n", + " key,\n", + " number,\n", + " contact_param_gridding_schedule[c2f_iter],\n", + " V_GRID,\n", + " O_GRID,\n", + " )\n", " path.append(trace_)\n", " # for c2f_iter in range(len(contact_param_gridding_schedule)):\n", " # trace_ = c2f_contact_update_jit(trace_, key, number,\n", " # contact_param_gridding_schedule[c2f_iter], VARIANCE_GRID, OUTLIER_GRID)\n", - " all_paths.append(\n", - " path\n", - " )\n", + " all_paths.append(path)\n", " all_all_paths.append(all_paths)\n", - " \n", + "\n", " scores = jnp.array([t[-1].get_score() for t in all_paths])\n", " print(scores)\n", " normalized_scores = b.utils.normalize_log_scores(scores)\n", @@ -404,7 +438,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.genjax.print_trace(trace)\n" + "b.genjax.print_trace(trace)" ] }, { diff --git a/scripts/experiments/tabletop/mug.ipynb b/scripts/experiments/tabletop/mug.ipynb index c39c0ced..35ee5cf5 100644 --- a/scripts/experiments/tabletop/mug.ipynb +++ b/scripts/experiments/tabletop/mug.ipynb @@ -36,20 +36,21 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=50.0, fy=50.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=20.0\n", + " height=100, width=100, fx=50.0, fy=50.0, cx=50.0, cy=50.0, near=0.01, far=20.0\n", ")\n", "\n", "b.setup_renderer(intrinsics)\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "meshes = []\n", - "for idx in range(1,22):\n", - " mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", - " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)\n", - "b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"), scaling_factor=1.0/1000000000.0)\n" + "for idx in range(1, 22):\n", + " mesh_path = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\"\n", + " )\n", + " b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)\n", + "b.RENDERER.add_mesh_from_file(\n", + " os.path.join(b.utils.get_assets_dir(), \"sample_objs/cube.obj\"),\n", + " scaling_factor=1.0 / 1000000000.0,\n", + ")" ] }, { @@ -61,7 +62,7 @@ "source": [ "table_pose = b.t3d.inverse_pose(\n", " b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.2, .05]),\n", + " jnp.array([0.0, 0.2, 0.05]),\n", " jnp.array([0.0, 0.0, 0.0]),\n", " jnp.array([0.0, 0.0, 
1.0]),\n", " )\n", @@ -110,17 +111,26 @@ "metadata": {}, "outputs": [], "source": [ - "contact_enumerators = [b.genjax.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"]) for i in range(5)]\n", + "contact_enumerators = [\n", + " b.genjax.make_enumerator([f\"contact_params_{i}\", \"variance\", \"outlier_prob\"])\n", + " for i in range(5)\n", + "]\n", "single_enumerators = b.genjax.make_enumerator([f\"contact_params_1\"])\n", "\n", - "def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):\n", + "\n", + "def c2f_contact_update(\n", + " trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID\n", + "):\n", " contact_param_grid = contact_param_deltas + trace_[f\"contact_params_{number}\"]\n", - " scores = contact_enumerators[number].enumerate_choices_get_scores(trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID)\n", - " i,j,k = jnp.unravel_index(scores.argmax(), scores.shape)\n", + " scores = contact_enumerators[number].enumerate_choices_get_scores(\n", + " trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID\n", + " )\n", + " i, j, k = jnp.unravel_index(scores.argmax(), scores.shape)\n", " return contact_enumerators[number].update_choices(\n", - " trace_, key,\n", - " contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", + " trace_, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", " )\n", + "\n", + "\n", "c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=(\"number\",))" ] }, @@ -146,30 +156,30 @@ "OUTLIER_VOLUME = 10.0\n", "\n", "grid_params = [\n", - " (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi/2, (11,11,11)), (0.1, jnp.pi/2, (11,11,11)),\n", - " (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.01, 0.0, (21,21,1))\n", + " (0.3, jnp.pi, (11, 11, 11)),\n", + " (0.2, jnp.pi / 2, (11, 11, 11)),\n", + " (0.1, jnp.pi / 2, (11, 11, 11)),\n", + " (0.05, jnp.pi / 3, (11, 11, 11)),\n", + " (0.02, jnp.pi, (5, 5, 51)),\n", + " (0.01, jnp.pi / 5, (11, 11, 11)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", + " (0.01, 0.0, (21, 21, 1)),\n", "]\n", "contact_param_gridding_schedule = [\n", - " b.utils.make_translation_grid_enumeration_3d(\n", - " -x, -x, -ang,\n", - " x, x, ang,\n", - " *nums\n", - " )\n", - " for (x,ang,nums) in grid_params\n", + " b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums)\n", + " for (x, ang, nums) in grid_params\n", "]\n", "\n", "width = 0.04\n", "ang = jnp.pi\n", "final_contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(\n", - " -width, -width, -ang,\n", - " width, width, ang,\n", - " 21,21,300\n", + " -width, -width, -ang, width, width, ang, 21, 21, 300\n", ")\n", "\n", + "\n", "def get_depth_image(image):\n", " mval = image[image < image.max()].max()\n", - " return b.get_depth_image(image, max=mval)\n", - "\n" + " return b.get_depth_image(image, max=mval)" ] }, { @@ -181,87 +191,117 @@ "source": [ "for experiment_iteration in tqdm(range(2)):\n", " print(key)\n", - " key = jax.random.split(key,1)[0]\n", - " \n", - " weight, gt_trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " # \"contact_params_1\": jnp.array([ 0.01630328 ,-0.06595182, -2.946241 ]),\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"variance\": VARIANCE_GRID[0],\n", - " 
\"outlier_prob\": OUTLIER_GRID[0],\n", - " }), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.1, -0.1, -2*jnp.pi]), jnp.array([0.1, 0.1, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, OUTLIER_VOLUME)\n", - " )\n", + " key = jax.random.split(key, 1)[0]\n", + "\n", + " weight, gt_trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " # \"contact_params_1\": jnp.array([ 0.01630328 ,-0.06595182, -2.946241 ]),\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"variance\": VARIANCE_GRID[0],\n", + " \"outlier_prob\": OUTLIER_GRID[0],\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [\n", + " jnp.array([-0.1, -0.1, -2 * jnp.pi]),\n", + " jnp.array([0.1, 0.1, 2 * jnp.pi]),\n", + " ]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " OUTLIER_VOLUME,\n", + " ),\n", + " )\n", " trace = gt_trace\n", " contact_param_grid = final_contact_param_deltas + trace[f\"contact_params_1\"]\n", - " weights = jnp.concatenate([\n", + " weights = jnp.concatenate(\n", + " [\n", " contact_enumerators[1].enumerate_choices_get_scores(\n", - " trace,\n", - " key,\n", - " d + trace[f\"contact_params_1\"],\n", - " VARIANCE_GRID,\n", - " OUTLIER_GRID\n", - " ) for d in jnp.array_split(final_contact_param_deltas, 55)\n", - " ],axis=0\n", + " trace, key, d + trace[f\"contact_params_1\"], VARIANCE_GRID, OUTLIER_GRID\n", + " )\n", + " for d in jnp.array_split(final_contact_param_deltas, 55)\n", + " ],\n", + " axis=0,\n", " )\n", - " \n", - " i,j,k = jnp.unravel_index(weights.argmax(), weights.shape)\n", - " trace= contact_enumerators[1].update_choices(\n", - " trace, key,\n", - " contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", + "\n", + " i, j, k = jnp.unravel_index(weights.argmax(), weights.shape)\n", + " trace = contact_enumerators[1].update_choices(\n", + " trace, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", " )\n", " print(trace[\"contact_params_1\"])\n", " print(trace[\"variance\"])\n", - " \n", + "\n", " print(trace.get_score())\n", - " \n", + "\n", " print(gt_trace[\"contact_params_1\"])\n", " print(gt_trace[\"variance\"])\n", " print(gt_trace.get_score())\n", - " \n", + "\n", " key2 = jax.random.PRNGKey(0)\n", " sampled_indices = jax.random.categorical(key2, weights.reshape(-1), shape=(1000,))\n", " sampled_indices = jnp.unravel_index(sampled_indices, weights.shape)[0]\n", " sampled_params = contact_param_grid[sampled_indices]\n", " actual_params = gt_trace[\"contact_params_1\"]\n", - " \n", + "\n", " fig = plt.figure(constrained_layout=True)\n", " widths = [1, 1]\n", " heights = [2]\n", - " spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", - " \n", + " spec = fig.add_gridspec(\n", + " ncols=2, nrows=1, width_ratios=widths, height_ratios=heights\n", + " )\n", + "\n", " ax = fig.add_subplot(spec[0, 0])\n", - " ax.imshow(jnp.array(get_depth_image(trace[\"image\"][...,2])))\n", + " ax.imshow(jnp.array(get_depth_image(trace[\"image\"][..., 2])))\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", " ax.set_title(\"Observed Depth\")\n", - " \n", - " \n", + "\n", " ax 
= fig.add_subplot(spec[0, 1])\n", " ax.set_aspect(1.0)\n", - " circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle=\"--\", linewidth=0.5)\n", + " circ = plt.Circle(\n", + " (0, 0),\n", + " radius=1,\n", + " edgecolor=\"black\",\n", + " facecolor=\"None\",\n", + " linestyle=\"--\",\n", + " linewidth=0.5,\n", + " )\n", " ax.add_patch(circ)\n", " ax.set_xlim(-2.0, 2.0)\n", " ax.set_ylim(-2.0, 2.0)\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.scatter(-jnp.sin(sampled_params[:,2]),-jnp.cos(sampled_params[:,2]),label=\"Posterior Samples\", alpha=0.5, s=15)\n", - " ax.scatter(-jnp.sin(actual_params[2]),-jnp.cos(actual_params[2]), color=(1.0, 0.0, 0.0),label=\"Actual\", alpha=0.9, s=10)\n", + " ax.scatter(\n", + " -jnp.sin(sampled_params[:, 2]),\n", + " -jnp.cos(sampled_params[:, 2]),\n", + " label=\"Posterior Samples\",\n", + " alpha=0.5,\n", + " s=15,\n", + " )\n", + " ax.scatter(\n", + " -jnp.sin(actual_params[2]),\n", + " -jnp.cos(actual_params[2]),\n", + " color=(1.0, 0.0, 0.0),\n", + " label=\"Actual\",\n", + " alpha=0.9,\n", + " s=10,\n", + " )\n", " ax.set_title(\"Posterior on Orientation (top view)\")\n", " ax.legend(fontsize=7)\n", " # plt.show()\n", - " plt.savefig(f'{experiment_iteration:05d}.png')\n", + " plt.savefig(f\"{experiment_iteration:05d}.png\")\n", " plt.clf()" ] }, @@ -275,112 +315,160 @@ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from ipywidgets import interact, interactive, FloatSlider, IntSlider, Button, Output, HBox, VBox, FloatLogSlider\n", + "from ipywidgets import (\n", + " interact,\n", + " interactive,\n", + " FloatSlider,\n", + " IntSlider,\n", + " Button,\n", + " Output,\n", + " HBox,\n", + " VBox,\n", + " FloatLogSlider,\n", + ")\n", "\n", - "out = Output(layout={'border': '5px solid black', \"height\" : '100px'})\n", + "out = Output(layout={\"border\": \"5px solid black\", \"height\": \"100px\"})\n", "\n", - "def func(x,y,ang,variance, outlier_prob, outlier_volume):\n", "\n", + "def func(x, y, ang, variance, outlier_prob, outlier_volume):\n", " VARIANCE_GRID = jnp.array([variance])\n", " OUTLIER_GRID = jnp.array([outlier_prob])\n", " OUTLIER_VOLUME = outlier_volume\n", "\n", - " weight, gt_trace = importance_jit(key, genjax.choice_map({\n", - " \"parent_0\": -1,\n", - " \"parent_1\": 0,\n", - " \"id_0\": jnp.int32(21),\n", - " \"id_1\": jnp.int32(13),\n", - " \"camera_pose\": jnp.eye(4),\n", - " \"root_pose_0\": table_pose,\n", - " \"contact_params_1\": jnp.array([ x,y,ang ]),\n", - " \"face_parent_1\": 2,\n", - " \"face_child_1\": 3,\n", - " \"variance\": VARIANCE_GRID[0],\n", - " \"outlier_prob\": OUTLIER_GRID[0],\n", - " }), (\n", - " jnp.arange(2),\n", - " jnp.arange(22),\n", - " jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),\n", - " jnp.array([jnp.array([-0.1, -0.1, -2*jnp.pi]), jnp.array([0.1, 0.1, 2*jnp.pi])]),\n", - " b.RENDERER.model_box_dims, OUTLIER_VOLUME)\n", - " )\n", + " weight, gt_trace = importance_jit(\n", + " key,\n", + " genjax.choice_map(\n", + " {\n", + " \"parent_0\": -1,\n", + " \"parent_1\": 0,\n", + " \"id_0\": jnp.int32(21),\n", + " \"id_1\": jnp.int32(13),\n", + " \"camera_pose\": jnp.eye(4),\n", + " \"root_pose_0\": table_pose,\n", + " \"contact_params_1\": jnp.array([x, y, ang]),\n", + " \"face_parent_1\": 2,\n", + " \"face_child_1\": 3,\n", + " \"variance\": VARIANCE_GRID[0],\n", + " \"outlier_prob\": OUTLIER_GRID[0],\n", + " }\n", + " ),\n", + " (\n", + " jnp.arange(2),\n", + " jnp.arange(22),\n", + " 
jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),\n", + " jnp.array(\n", + " [\n", + " jnp.array([-0.1, -0.1, -2 * jnp.pi]),\n", + " jnp.array([0.1, 0.1, 2 * jnp.pi]),\n", + " ]\n", + " ),\n", + " b.RENDERER.model_box_dims,\n", + " OUTLIER_VOLUME,\n", + " ),\n", + " )\n", " trace = gt_trace\n", " contact_param_grid = final_contact_param_deltas + trace[f\"contact_params_1\"]\n", - " weights = jnp.concatenate([\n", + " weights = jnp.concatenate(\n", + " [\n", " contact_enumerators[1].enumerate_choices_get_scores(\n", - " trace,\n", - " key,\n", - " d + trace[f\"contact_params_1\"],\n", - " VARIANCE_GRID,\n", - " OUTLIER_GRID\n", - " ) for d in jnp.array_split(final_contact_param_deltas, 55)\n", - " ],axis=0\n", + " trace, key, d + trace[f\"contact_params_1\"], VARIANCE_GRID, OUTLIER_GRID\n", + " )\n", + " for d in jnp.array_split(final_contact_param_deltas, 55)\n", + " ],\n", + " axis=0,\n", " )\n", - " \n", - " i,j,k = jnp.unravel_index(weights.argmax(), weights.shape)\n", - " trace= contact_enumerators[1].update_choices(\n", - " trace, key,\n", - " contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", + "\n", + " i, j, k = jnp.unravel_index(weights.argmax(), weights.shape)\n", + " trace = contact_enumerators[1].update_choices(\n", + " trace, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]\n", " )\n", " print(trace[\"contact_params_1\"])\n", " print(trace[\"variance\"])\n", - " \n", + "\n", " print(trace.get_score())\n", - " \n", + "\n", " print(gt_trace[\"contact_params_1\"])\n", " print(gt_trace[\"variance\"])\n", " print(gt_trace.get_score())\n", - " \n", + "\n", " key2 = jax.random.PRNGKey(0)\n", " sampled_indices = jax.random.categorical(key2, weights.reshape(-1), shape=(1000,))\n", " sampled_indices = jnp.unravel_index(sampled_indices, weights.shape)[0]\n", " sampled_params = contact_param_grid[sampled_indices]\n", " actual_params = gt_trace[\"contact_params_1\"]\n", - " \n", + "\n", " fig = plt.figure(constrained_layout=True)\n", " widths = [1, 1]\n", " heights = [2]\n", - " spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,\n", - " height_ratios=heights)\n", - " \n", + " spec = fig.add_gridspec(\n", + " ncols=2, nrows=1, width_ratios=widths, height_ratios=heights\n", + " )\n", + "\n", " ax = fig.add_subplot(spec[0, 0])\n", - " ax.imshow(jnp.array(get_depth_image(trace[\"image\"][...,2])))\n", + " ax.imshow(jnp.array(get_depth_image(trace[\"image\"][..., 2])))\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", " ax.set_title(\"Observed Depth\")\n", - " \n", - " \n", + "\n", " ax = fig.add_subplot(spec[0, 1])\n", " ax.set_aspect(1.0)\n", - " circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle=\"--\", linewidth=0.5)\n", + " circ = plt.Circle(\n", + " (0, 0),\n", + " radius=1,\n", + " edgecolor=\"black\",\n", + " facecolor=\"None\",\n", + " linestyle=\"--\",\n", + " linewidth=0.5,\n", + " )\n", " ax.add_patch(circ)\n", " ax.set_xlim(-2.0, 2.0)\n", " ax.set_ylim(-2.0, 2.0)\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " ax.scatter(-jnp.sin(sampled_params[:,2]),-jnp.cos(sampled_params[:,2]),label=\"Posterior Samples\", alpha=0.5, s=15)\n", - " ax.scatter(-jnp.sin(actual_params[2]),-jnp.cos(actual_params[2]), color=(1.0, 0.0, 0.0),label=\"Actual\", alpha=0.9, s=10)\n", + " ax.scatter(\n", + " -jnp.sin(sampled_params[:, 2]),\n", + " -jnp.cos(sampled_params[:, 2]),\n", + " label=\"Posterior Samples\",\n", + " alpha=0.5,\n", + " s=15,\n", + " )\n", + " 
ax.scatter(\n", + " -jnp.sin(actual_params[2]),\n", + " -jnp.cos(actual_params[2]),\n", + " color=(1.0, 0.0, 0.0),\n", + " label=\"Actual\",\n", + " alpha=0.9,\n", + " s=10,\n", + " )\n", " ax.set_title(\"Posterior on Orientation (top view)\")\n", " ax.legend(fontsize=7)\n", " # plt.show()\n", " # plt.savefig(f'{experiment_iteration:05d}.png')\n", " # plt.clf()\n", - " \n", - " with out: \n", + "\n", + " with out:\n", " out.clear_output()\n", " display(f\"variance = {variance}\")\n", " display(f\"outlier_prob = {outlier_prob}\")\n", " display(f\"outlier_volume = {outlier_volume}\")\n", "\n", - "w = interactive(func, \n", - " x = FloatSlider(min=-0.1, max=0.1, value=0.0, description=\" x:\"),\n", - " y = FloatSlider(min=-0.1, max=0.1, value=0.0, description=\" y:\"),\n", - " ang = FloatSlider(min=-jnp.pi, max=jnp.pi, value=jnp.pi, description=\" ang:\"),\n", - " variance = FloatLogSlider(base=10.0, min=-9, max=1, value=0.000501, description=\"variance:\"),\n", - " outlier_prob = FloatLogSlider(base=10.0, min=-4, max=0, value=0.631, description=\"outlier_prob:\"),\n", - " outlier_volume = FloatLogSlider(base=10.0, min=1, max=5, value=10.0, description=\"outlier_volume:\")\n", - ");\n", - "display(VBox([w,out]))" + "\n", + "w = interactive(\n", + " func,\n", + " x=FloatSlider(min=-0.1, max=0.1, value=0.0, description=\" x:\"),\n", + " y=FloatSlider(min=-0.1, max=0.1, value=0.0, description=\" y:\"),\n", + " ang=FloatSlider(min=-jnp.pi, max=jnp.pi, value=jnp.pi, description=\" ang:\"),\n", + " variance=FloatLogSlider(\n", + " base=10.0, min=-9, max=1, value=0.000501, description=\"variance:\"\n", + " ),\n", + " outlier_prob=FloatLogSlider(\n", + " base=10.0, min=-4, max=0, value=0.631, description=\"outlier_prob:\"\n", + " ),\n", + " outlier_volume=FloatLogSlider(\n", + " base=10.0, min=1, max=5, value=10.0, description=\"outlier_volume:\"\n", + " ),\n", + ")\n", + "display(VBox([w, out]))" ] }, { diff --git a/scripts/experiments/tabletop/voxel_learning.ipynb b/scripts/experiments/tabletop/voxel_learning.ipynb index 69fd4dc8..170e8624 100644 --- a/scripts/experiments/tabletop/voxel_learning.ipynb +++ b/scripts/experiments/tabletop/voxel_learning.ipynb @@ -35,10 +35,12 @@ "metadata": {}, "outputs": [], "source": [ - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "obj_idx = 4\n", - "mesh_filename = os.path.join(model_dir,\"obj_\" + \"{}\".format(obj_idx+1).rjust(6, '0') + \".ply\")\n", - "SCALING_FACTOR = 1.0/1000.0" + "mesh_filename = os.path.join(\n", + " model_dir, \"obj_\" + \"{}\".format(obj_idx + 1).rjust(6, \"0\") + \".ply\"\n", + ")\n", + "SCALING_FACTOR = 1.0 / 1000.0" ] }, { @@ -49,11 +51,7 @@ "outputs": [], "source": [ "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=500.0, fy=500.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.01, far=50.0\n", + " height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=50.0\n", ")\n", "b.setup_renderer(intrinsics)\n", "b.RENDERER.add_mesh_from_file(mesh_filename, scaling_factor=SCALING_FACTOR)" @@ -66,12 +64,20 @@ "metadata": {}, "outputs": [], "source": [ - "object_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.6, 0.6]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in jnp.linspace(-jnp.pi, jnp.pi, 7)[:-1]])\n", - "observations = 
b.RENDERER.render_many(object_poses[:,None,...], jnp.array([0]))" + "object_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.6, 0.6]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in jnp.linspace(-jnp.pi, jnp.pi, 7)[:-1]\n", + " ]\n", + ")\n", + "observations = b.RENDERER.render_many(object_poses[:, None, ...], jnp.array([0]))" ] }, { @@ -81,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "b.hstack_images([b.get_depth_image(o[...,2]) for o in observations])" + "b.hstack_images([b.get_depth_image(o[..., 2]) for o in observations])" ] }, { @@ -92,10 +98,16 @@ "outputs": [], "source": [ "grid = b.utils.make_translation_grid_enumeration_3d(\n", - " -0.1, -0.1, -0.2,\n", - " 0.1, 0.1, 0.2,\n", + " -0.1,\n", + " -0.1,\n", + " -0.2,\n", + " 0.1,\n", + " 0.1,\n", + " 0.2,\n", " # 100, 100, 100\n", - " 60, 60, 60\n", + " 60,\n", + " 60,\n", + " 60,\n", ")\n", "b.show_cloud(\"grid\", grid)" ] @@ -107,7 +119,9 @@ "metadata": {}, "outputs": [], "source": [ - "voxel_occupied_occluded_free_parallel = jax.jit(jax.vmap(b.utils.voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None)))" + "voxel_occupied_occluded_free_parallel = jax.jit(\n", + " jax.vmap(b.utils.voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None))\n", + ")" ] }, { @@ -118,7 +132,7 @@ "outputs": [], "source": [ "occupancies = voxel_occupied_occluded_free_parallel(\n", - " b.inverse_pose(object_poses), observations[...,2], grid, intrinsics, 0.001\n", + " b.inverse_pose(object_poses), observations[..., 2], grid, intrinsics, 0.001\n", ")\n", "print(occupancies.sum())" ] @@ -131,7 +145,7 @@ "outputs": [], "source": [ "b.clear()\n", - "b.show_cloud(\"grid\", grid[(occupancies > 0.6).sum(0) > 0 ])\n", + "b.show_cloud(\"grid\", grid[(occupancies > 0.6).sum(0) > 0])\n", "# b.show_cloud(\"grid2\", grid[occupancy == 0.5],color=b.RED)" ] }, @@ -142,7 +156,9 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = b.utils.make_voxel_mesh_from_point_cloud(grid[(occupancies > 0.6).sum(0) > 0 ], 0.01 )" + "mesh = b.utils.make_voxel_mesh_from_point_cloud(\n", + " grid[(occupancies > 0.6).sum(0) > 0], 0.01\n", + ")" ] }, { @@ -203,7 +219,7 @@ "metadata": {}, "outputs": [], "source": [ - "viz.make_trimesh(mesh, jnp.eye(4), jnp.array([*distinct_colors[0], 1.0]))" + "viz.make_trimesh(mesh, jnp.eye(4), jnp.array([*distinct_colors[0], 1.0]))" ] }, { @@ -213,11 +229,19 @@ "metadata": {}, "outputs": [], "source": [ - "view_poses = jnp.array([b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(\n", - " jnp.array([0.0, 0.9, 0.9]),\n", - " jnp.array([0.0, 0.0, 0.0]),\n", - " jnp.array([0.0, 0.0, 1.0]),\n", - " )) @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) for angle in jnp.linspace(-jnp.pi, jnp.pi, 100)[:-1]])" + "view_poses = jnp.array(\n", + " [\n", + " b.t3d.inverse_pose(\n", + " b.t3d.transform_from_pos_target_up(\n", + " jnp.array([0.0, 0.9, 0.9]),\n", + " jnp.array([0.0, 0.0, 0.0]),\n", + " jnp.array([0.0, 0.0, 1.0]),\n", + " )\n", + " )\n", + " @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)\n", + " for angle in jnp.linspace(-jnp.pi, jnp.pi, 100)[:-1]\n", + " ]\n", + ")" ] }, { @@ -239,7 +263,9 @@ "metadata": {}, "outputs": [], "source": [ - "b.viz.make_gif_from_pil_images([b.get_rgb_image(rgbd.rgb) for rgbd in images], \"out.gif\")" + "b.viz.make_gif_from_pil_images(\n", + 
" [b.get_rgb_image(rgbd.rgb) for rgbd in images], \"out.gif\"\n", + ")" ] }, { diff --git a/test/test_colmap.ipynb b/test/test_colmap.ipynb index bcc78a67..f2196c44 100644 --- a/test/test_colmap.ipynb +++ b/test/test_colmap.ipynb @@ -38,7 +38,9 @@ "outputs": [], "source": [ "movie_file_path = Path(\"/home/nishadgothoskar/bayes3d/assets/can.MOV\")\n", - "dataset_path = Path(b.utils.get_assets_dir()) / Path(movie_file_path.name + \"_colmap_dataset\")\n", + "dataset_path = Path(b.utils.get_assets_dir()) / Path(\n", + " movie_file_path.name + \"_colmap_dataset\"\n", + ")\n", "input_path = dataset_path / Path(\"input\")\n", "input_path.mkdir(parents=True, exist_ok=True)" ] @@ -63,10 +65,8 @@ "images = [b.viz.load_image_from_file(f) for f in image_paths]\n", "# b.make_gif_from_pil_images(images, \"input.gif\")\n", "(positions, colors, normals), train_cam_infos = b.colmap.readColmapSceneInfo(\n", - " dataset_path,\n", - " \"images\",\n", - " False\n", - ")\n" + " dataset_path, \"images\", False\n", + ")" ] }, { @@ -104,8 +104,8 @@ "]\n", "\n", "b.show_cloud(\"cloud\", positions * scaling_factor)\n", - "for (i,p) in enumerate(poses):\n", - " b.show_pose(f\"{i}\", p)\n" + "for i, p in enumerate(poses):\n", + " b.show_pose(f\"{i}\", p)" ] }, { diff --git a/test/test_jax_renderer.ipynb b/test/test_jax_renderer.ipynb index 3db223d7..31ae15b2 100644 --- a/test/test_jax_renderer.ipynb +++ b/test/test_jax_renderer.ipynb @@ -17,29 +17,26 @@ "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", "\n", "intrinsics = b.Intrinsics(\n", - " height=100,\n", - " width=100,\n", - " fx=75.0, fy=75.0,\n", - " cx=50.0, cy=50.0,\n", - " near=0.001, far=16.0\n", + " height=100, width=100, fx=75.0, fy=75.0, cx=50.0, cy=50.0, near=0.001, far=16.0\n", ")\n", "from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer\n", + "\n", "jax_renderer = JaxRenderer(intrinsics)\n", "\n", "\n", - "model_dir = os.path.join(b.utils.get_assets_dir(),\"bop/ycbv/models\")\n", + "model_dir = os.path.join(b.utils.get_assets_dir(), \"bop/ycbv/models\")\n", "idx = 17\n", - "mesh_path = os.path.join(model_dir,\"obj_\" + \"{}\".format(idx).rjust(6, '0') + \".ply\")\n", + "mesh_path = os.path.join(model_dir, \"obj_\" + \"{}\".format(idx).rjust(6, \"0\") + \".ply\")\n", "m = b.utils.load_mesh(mesh_path)\n", - "m = b.utils.scale_mesh(m, 1.0/100.0)\n", + "m = b.utils.scale_mesh(m, 1.0 / 100.0)\n", "\n", "vertices = jnp.array(m.vertices.astype(np.float32))\n", "faces = jnp.array(m.faces.astype(np.int32))\n", "\n", "pose = b.transform_from_pos(jnp.array([0.0, 0.0, 2.0]))\n", "NUM_POSES = 50\n", - "poses = jnp.tile(pose[None,...], (NUM_POSES,1,1))\n", - "poses = poses.at[:,0,3].set(jnp.linspace(-1.0, 1.0, NUM_POSES))" + "poses = jnp.tile(pose[None, ...], (NUM_POSES, 1, 1))\n", + "poses = poses.at[:, 0, 3].set(jnp.linspace(-1.0, 1.0, NUM_POSES))" ] }, { @@ -128,19 +125,39 @@ ], "source": [ "projection_matrix = b.camera._open_gl_projection_matrix(\n", - " intrinsics.height, intrinsics.width, \n", - " intrinsics.fx, intrinsics.fy, \n", - " intrinsics.cx, intrinsics.cy, \n", - " intrinsics.near, intrinsics.far\n", + " intrinsics.height,\n", + " intrinsics.width,\n", + " intrinsics.fx,\n", + " intrinsics.fy,\n", + " intrinsics.cx,\n", + " intrinsics.cy,\n", + " intrinsics.near,\n", + " intrinsics.far,\n", ")\n", "composed_projection = projection_matrix @ poses\n", - "vertices_homogenous = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1)\n", - "clip_spaces_projected_vertices = 
jnp.einsum(\"nij,mj->nmi\", composed_projection, vertices_homogenous)\n", - "rast_out, rast_out_db = jax_renderer.rasterize(clip_spaces_projected_vertices, faces, jnp.array([intrinsics.height, intrinsics.width]))\n", - "interpolated_collided_vertices_clip, _ = jax_renderer.interpolate(jnp.tile(vertices_homogenous[None,...],(poses.shape[0],1,1)), rast_out, faces, rast_out_db, jnp.array([0,1,2,3]))\n", - "interpolated_collided_vertices = jnp.einsum(\"a...ij,a...j->a...i\", poses, interpolated_collided_vertices_clip)\n", - "mask = rast_out[...,-1] > 0\n", - "depth = interpolated_collided_vertices[...,2] * mask\n", + "vertices_homogenous = jnp.concatenate(\n", + " [vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1\n", + ")\n", + "clip_spaces_projected_vertices = jnp.einsum(\n", + " \"nij,mj->nmi\", composed_projection, vertices_homogenous\n", + ")\n", + "rast_out, rast_out_db = jax_renderer.rasterize(\n", + " clip_spaces_projected_vertices,\n", + " faces,\n", + " jnp.array([intrinsics.height, intrinsics.width]),\n", + ")\n", + "interpolated_collided_vertices_clip, _ = jax_renderer.interpolate(\n", + " jnp.tile(vertices_homogenous[None, ...], (poses.shape[0], 1, 1)),\n", + " rast_out,\n", + " faces,\n", + " rast_out_db,\n", + " jnp.array([0, 1, 2, 3]),\n", + ")\n", + "interpolated_collided_vertices = jnp.einsum(\n", + " \"a...ij,a...j->a...i\", poses, interpolated_collided_vertices_clip\n", + ")\n", + "mask = rast_out[..., -1] > 0\n", + "depth = interpolated_collided_vertices[..., 2] * mask\n", "print(depth.shape)\n", "b.get_depth_image(depth[0])" ]