hash_encoding.py
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import functools as ft

def hash_vertex(v, hashmap_size):
    """Spatial hash of an integer grid vertex (XOR of per-axis primes, modulo table size)."""
    primes = jnp.array([1, 2654435761, 805459861], dtype=np.uint32)
    h = np.uint32(0)
    for i in range(len(v)):
        h ^= v[i] * primes[i]
    return h % hashmap_size

def interpolate_bilinear(values, weights):
    # `values` holds the four corner features in unit_box(2) order:
    # (0,0), (0,1), (1,0), (1,1); `weights` are the fractional coordinates.
    assert weights.shape == (2,)
    c0 = values[0] * (1.0 - weights[0]) + values[2] * weights[0]
    c1 = values[1] * (1.0 - weights[0]) + values[3] * weights[0]
    c = c0 * (1.0 - weights[1]) + c1 * weights[1]
    return c

def interpolate_trilinear(values, weights):
    # https://en.wikipedia.org/wiki/Trilinear_interpolation
    # `values` holds the eight corner features in unit_box(3) order.
    assert weights.shape == (3,)
    c00 = values[0] * (1.0 - weights[0]) + values[4] * weights[0]
    c01 = values[1] * (1.0 - weights[0]) + values[5] * weights[0]
    c10 = values[2] * (1.0 - weights[0]) + values[6] * weights[0]
    c11 = values[3] * (1.0 - weights[0]) + values[7] * weights[0]
    c0 = c00 * (1.0 - weights[1]) + c10 * weights[1]
    c1 = c01 * (1.0 - weights[1]) + c11 * weights[1]
    c = c0 * (1.0 - weights[2]) + c1 * weights[2]
    return c

def interpolate_dlinear(values, weights):
    # Dispatch on the input dimensionality.
    dim, = weights.shape
    if dim == 2: return interpolate_bilinear(values, weights)
    elif dim == 3: return interpolate_trilinear(values, weights)
    else: assert False, f"unsupported dimension: {dim}"

def unit_box(dim: int):
    # Corner offsets of the d-dimensional unit cube, in lexicographic order.
    if dim == 2: return np.array([[i, j] for i in (0, 1) for j in (0, 1)], dtype=np.uint32)
    elif dim == 3: return np.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)], dtype=np.uint32)
    else: assert False, f"unsupported dimension: {dim}"
@ft.partial(jax.jit, static_argnames=("nmin", "nmax"))
def encode(x, theta, nmin=16, nmax=512):
"""Multiresolution Hash Encoding.
Following the paper:
Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
Thomas Müller, Alex Evans, Christoph Schied, Alexander Keller
ACM Transactions on Graphics (SIGGRAPH), July 2022
The present code takes only a single input vector in 2D or 3D to encode.
If you need to encode large batches of inputs jointly, consider
wrapping this function with `jax.vmap`.
Args:
x: Float[input_dim]
theta: Float[levels, hashmap_size, features_per_entry]
Returns:
Float[levels, features_per_entry]
"""
input_dim, = x.shape
levels, hashmap_size, features_per_entry = theta.shape
box = unit_box(input_dim)
b = np.exp((np.log(nmax) - np.log(nmin)) / (levels - 1))
def features(l):
nl = jnp.floor(nmin * b**l)
xl = x * nl
xl_ = jnp.floor(xl).astype(np.uint32)
# hash voxel vertices
indices = jax.vmap(lambda v: hash_vertex(xl_ + v, hashmap_size))(box)
# lookup
tl = theta[l][indices]
# interpolate
wl = (xl - xl_)
xi = interpolate_dlinear(tl, wl)
return xi
return jax.lax.map(features, np.arange(levels, dtype=np.uint32))

def init_encoding(
    key,
    levels: int = 16,
    hashmap_size_log2: int = 14,
    features_per_entry: int = 2,
):
    """Initialize the per-level hash tables with small uniform random values."""
    hashmap_size = 1 << hashmap_size_log2
    theta = jrandom.uniform(
        key,
        (levels, hashmap_size, features_per_entry),
        minval=-0.0001,
        maxval=0.0001,
    )
    return theta
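

# A minimal usage sketch: encode one 2D point, then a batch of points via
# `jax.vmap`, as suggested in `encode`'s docstring. The key seeds, batch size,
# and point values below are arbitrary illustration choices.
if __name__ == "__main__":
    key = jrandom.PRNGKey(0)
    theta = init_encoding(key)  # shape (16, 2**14, 2): learnable hash tables

    # single point in [0, 1]^2
    x = jnp.array([0.25, 0.75])
    features = encode(x, theta)
    print(features.shape)  # (16, 2): one feature vector per resolution level

    # batch of points: map over the inputs while sharing the same hash tables
    xs = jrandom.uniform(jrandom.PRNGKey(1), (128, 2))
    batch = jax.vmap(lambda p: encode(p, theta))(xs)
    print(batch.shape)  # (128, 16, 2)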