-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathnested_vmap_grad.py
72 lines (63 loc) · 2.93 KB
/
nested_vmap_grad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Demonstration of composing ``vmap`` with ``grad``/``jacfwd``/``jacrev``/``hessian``.

For each backend ("tensorflow", then "jax") the same scalar function ``f`` is
differentiated and vectorized in several different nestings, and every result
is printed so the two backends can be compared by eye.  Calls preceded by a
"wrong in tf" comment are nestings reported (see the linked issues) to give
incorrect results on the TensorFlow backend.
"""

import os
import sys

# Prepend the repository root to sys.path before importing tensorcircuit
# (presumably so a local checkout is used instead of an installed copy).
# The directory is resolved relative to this file rather than the current
# working directory: a bare "../" entry is interpreted relative to the CWD and
# only pointed at the right place when the script was launched from its own folder.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import tensorcircuit as tc

# See issues in https://github.com/tencent-quantum-lab/tensorcircuit/issues/229#issuecomment-2600773780

L = 2  # number of qubits in the circuit
nlayers = 2  # number of ansatz layers == length of each parameter vector
batch = 3  # number of parameter vectors that vmap maps over

for backend in ["tensorflow", "jax"]:
    with tc.runtime_backend(backend) as K:
        # All-ones parameter batches of shape (batch, nlayers), cast to the real dtype.
        # Row k is one parameter vector; inputs[0] / weights[0] are single vectors.
        inputs = K.cast(K.ones([batch, nlayers]), tc.rdtypestr)
        weights = K.cast(K.ones([batch, nlayers]), tc.rdtypestr)

        def ansatz(thetas, alpha):
            """Build an ``L``-qubit circuit with ``nlayers`` layers.

            Each layer applies ``rx(thetas[j])`` then ``ry(alpha[j])`` to every
            qubit, followed by a nearest-neighbour CNOT ladder.

            :param thetas: rx angles, indexed by layer ``j`` (length >= nlayers).
            :param alpha: ry angles, indexed by layer ``j`` (length >= nlayers).
            :return: the constructed ``tc.Circuit``.
            """
            c = tc.Circuit(L)
            for j in range(nlayers):
                for i in range(L):
                    c.rx(i, theta=thetas[j])
                    c.ry(i, theta=alpha[j])
                for i in range(L - 1):
                    c.cnot(i, i + 1)
            return c

        def f(thetas, alpha):
            """Return the mean over all qubits of the real part of <Z_i>."""
            c = ansatz(thetas, alpha)
            observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)])
            return K.mean(observables)

        print("grad_0", K.grad(f)(inputs[0], weights[0]))
        print("grad_1", K.grad(f, argnums=1)(inputs[0], weights[0]))
        print("vmap_0", K.vmap(f)(inputs, weights[0]))
        print("vmap_1", K.vmap(f, vectorized_argnums=1)(inputs[0], weights))
        print("vmap over grad_0", K.vmap(K.grad(f))(inputs, weights[0]))
        # wrong in tf due to https://github.com/google/TensorNetwork/issues/940
        # https://github.com/tensorflow/tensorflow/issues/52148
        print("vmap over grad_1", K.vmap(K.grad(f, argnums=1))(inputs, weights[0]))
        # wrong in tf
        print("vmap over jacfwd_0", K.vmap(K.jacfwd(f))(inputs, weights[0]))
        print("jacfwd_0 over vmap", K.jacfwd(K.vmap(f))(inputs, weights[0]))
        print("vmap over jacfwd_1", K.vmap(K.jacfwd(f, argnums=1))(inputs, weights[0]))
        print("jacfwd_1 over vmap", K.jacfwd(K.vmap(f), argnums=1)(inputs, weights[0]))
        r = K.vmap(K.jacrev(f))(inputs, weights[0])
        print("vmap over jacrev0", r)
        # wrong in tf
        r = K.jacrev(K.vmap(f))(inputs, weights[0])
        print("jacrev0 over vmap", r)
        r = K.vmap(K.jacrev(f, argnums=1))(inputs, weights[0])
        print("vmap over jacrev1", r)
        # wrong in tf
        r = K.jacrev(K.vmap(f), argnums=1)(inputs, weights[0])
        print("jacrev1 over vmap", r)
        r = K.vmap(K.jacrev(f, argnums=1), vectorized_argnums=1)(inputs[0], weights)
        print("vmap1 over jacrev1", r)
        r = K.jacrev(K.vmap(f, vectorized_argnums=1), argnums=1)(inputs[0], weights)
        print("jacrev1 over vmap1", r)
        r = K.vmap(K.hessian(f))(inputs, weights[0])
        print("vmap over hess0", r)
        # wrong in tf
        r = K.hessian(K.vmap(f))(inputs, weights[0])
        print("hess0 over vmap", r)
        r = K.vmap(K.hessian(f, argnums=1))(inputs, weights[0])
        print("vmap over hess1", r)
        # wrong in tf
        r = K.hessian(K.vmap(f), argnums=1)(inputs, weights[0])
        print("hess1 over vmap", r)

# lessons: never put vmap outside gradient function in tf
# NOTE(review): the "wrong in tf" marks above also flag jacrev/hessian applied
# *over* vmap (vmap on the inside), so this one-line summary looks incomplete
# -- confirm against the linked issues before relying on it.