-
Notifications
You must be signed in to change notification settings - Fork 938
/
ode_demo.py
182 lines (135 loc) · 5.51 KB
/
ode_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Command-line configuration for the demo.
parser = argparse.ArgumentParser('ODE demo')
parser.add_argument('--method', type=str, default='dopri5', choices=['dopri5', 'adams'])
parser.add_argument('--data_size', type=int, default=1000)  # number of points on the time grid ``t``
parser.add_argument('--batch_time', type=int, default=10)   # consecutive steps per sampled window
parser.add_argument('--batch_size', type=int, default=20)   # windows drawn per training iteration
parser.add_argument('--niters', type=int, default=2000)     # number of training iterations
parser.add_argument('--test_freq', type=int, default=20)    # evaluate every ``test_freq`` iterations
parser.add_argument('--viz', action='store_true')           # enable matplotlib visualisation
parser.add_argument('--gpu', type=int, default=0)           # CUDA device index, used only if CUDA exists
parser.add_argument('--adjoint', action='store_true')       # use odeint_adjoint instead of odeint
args = parser.parse_args()

# Bind one of torchdiffeq's two integrators to the single name ``odeint``.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint

# Prefer the requested GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:' + str(args.gpu))
else:
    device = torch.device('cpu')

# Ground-truth problem: a (1, 2) initial state, ``args.data_size`` evenly
# spaced times on [0, 25], and the 2x2 matrix that Lambda applies to y**3.
true_y0 = torch.tensor([[2., 0.]]).to(device)
t = torch.linspace(0., 25., args.data_size).to(device)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)
class Lambda(nn.Module):
    """Ground-truth dynamics: dy/dt = (y ** 3) @ true_A."""

    def forward(self, t, y):
        # ``t`` is accepted for the odeint calling convention but not used.
        cubed = y ** 3
        return cubed.mm(true_A)
# Integrate the ground-truth dynamics once with autograd disabled; training
# batches and evaluation both compare against this trajectory.
with torch.set_grad_enabled(False):
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')
def get_batch():
    """Sample a random mini-batch of short windows from ``true_y``.

    ``args.batch_size`` distinct start indices are drawn without replacement
    from ``[0, args.data_size - args.batch_time)``. From each start,
    ``args.batch_time`` consecutive states are taken.

    Returns:
        A tuple ``(batch_y0, batch_t, batch_y)``, all moved to ``device``:
        - ``batch_y0`` is ``true_y`` at the start indices, shaped
          ``(batch_size, *true_y.shape[1:])``.
        - ``batch_t`` is the first ``args.batch_time`` entries of ``t``.
        - ``batch_y`` stacks the windows with time first, shaped
          ``(batch_time, batch_size, *true_y.shape[1:])``.
    """
    candidate_starts = np.arange(args.data_size - args.batch_time, dtype=np.int64)
    starts = torch.from_numpy(np.random.choice(candidate_starts, args.batch_size, replace=False))
    start_states = true_y[starts]
    window_times = t[:args.batch_time]
    # Entry k of ``windows`` holds the state k steps after every start index.
    windows = [true_y[starts + offset] for offset in range(args.batch_time)]
    stacked = torch.stack(windows, dim=0)
    return start_states.to(device), window_times.to(device), stacked.to(device)
def makedirs(dirname):
    """Create ``dirname`` and any missing parents unless it already exists.

    ``exist_ok=True`` replaces the previous ``os.path.exists`` check followed by
    ``os.makedirs``. That pair could race: if the directory appeared between
    the two calls, ``os.makedirs`` raised ``FileExistsError``.

    Raises:
        FileExistsError: if ``dirname`` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)
if args.viz:
    # Directory that visualize() saves its frames into.
    makedirs('png')
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    # Three frameless panels side by side: trajectories, phase portrait, vector field.
    ax_traj, ax_phase, ax_vecfield = [fig.add_subplot(position, frameon=False) for position in (131, 132, 133)]
    # Non-blocking so the training loop keeps running while the window is open.
    plt.show(block=False)
def _plot_trajectories(ax, t_np, true_np, pred_np):
    """Draw both state components over time for the true and predicted trajectories."""
    ax.cla()
    ax.set_title('Trajectories')
    ax.set_xlabel('t')
    ax.set_ylabel('x,y')
    # One call per curve so every line carries a label. The previous unlabeled
    # multi-line plot() calls left ax.legend() with no artists to show.
    ax.plot(t_np, true_np[:, 0, 0], label='true x')
    ax.plot(t_np, true_np[:, 0, 1], 'g-', label='true y')
    ax.plot(t_np, pred_np[:, 0, 0], '--', label='pred x')
    ax.plot(t_np, pred_np[:, 0, 1], 'b--', label='pred y')
    ax.set_xlim(t_np.min(), t_np.max())
    ax.set_ylim(-2, 2)
    ax.legend()


def _plot_phase_portrait(ax, true_np, pred_np):
    """Draw the (x, y) phase-plane curves of the true and predicted trajectories."""
    ax.cla()
    ax.set_title('Phase Portrait')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.plot(true_np[:, 0, 0], true_np[:, 0, 1], 'g-')
    ax.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], 'b--')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)


def _plot_vector_field(ax, odefunc, n_grid=21):
    """Stream-plot the direction of ``odefunc`` on an n_grid x n_grid grid over [-2, 2]^2."""
    ax.cla()
    ax.set_title('Learned Vector Field')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    y, x = np.mgrid[-2:2:complex(0, n_grid), -2:2:complex(0, n_grid)]
    points = np.stack([x, y], -1).reshape(n_grid * n_grid, 2)
    dydt = odefunc(0, torch.Tensor(points).to(device)).cpu().detach().numpy()
    mag = np.sqrt(dydt[:, 0] ** 2 + dydt[:, 1] ** 2).reshape(-1, 1)
    # Normalise to unit length so only direction is shown. Exactly-zero
    # vectors stay zero instead of becoming 0/0 = NaN.
    mag[mag == 0] = 1.0
    dydt = (dydt / mag).reshape(n_grid, n_grid, 2)
    ax.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)


def visualize(true_y, pred_y, odefunc, itr):
    """Redraw the three diagnostic panels and save the figure as ``png/{itr:03d}``.

    This is a no-op unless ``--viz`` was given, because ``fig``, the axes and
    ``plt`` exist only in that case.

    Args:
        true_y: reference trajectory. It is indexed as ``[:, 0, k]``, so it must
            be rank 3 with the state components on the last axis.
        pred_y: predicted trajectory on the same time grid ``t``. It must not
            require grad, since it is converted with ``.numpy()``.
        odefunc: callable ``odefunc(t, y)`` evaluated on a grid of 2-D points.
        itr: frame number used in the saved file name.
    """
    if not args.viz:
        return
    # Convert each tensor to NumPy once instead of once per plotted line.
    t_np = t.cpu().numpy()
    true_np = true_y.cpu().numpy()
    pred_np = pred_y.cpu().numpy()
    _plot_trajectories(ax_traj, t_np, true_np, pred_np)
    _plot_phase_portrait(ax_phase, true_np, pred_np)
    _plot_vector_field(ax_vecfield, odefunc)
    fig.tight_layout()
    plt.savefig('png/{:03d}'.format(itr))
    plt.draw()
    plt.pause(0.001)
class ODEFunc(nn.Module):
    """Learned dynamics: a 2 -> 50 -> 2 tanh MLP applied to the element-wise cube of the state."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(2, 50), nn.Tanh(), nn.Linear(50, 2)]
        self.net = nn.Sequential(*layers)
        # Every linear layer gets N(0, 0.1) weights and zero biases.
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        # ``t`` is unused. Cubing the input mirrors the y**3 in Lambda.
        return self.net(y.pow(3))
class RunningAverageMeter:
    """Track the latest value and an exponential moving average of all values.

    The first ``update`` seeds the average with the value itself. Later
    updates compute ``avg = momentum * avg + (1 - momentum) * val``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all recorded values."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Record ``val`` as the current value and fold it into the average."""
        first_sample = self.val is None
        self.avg = val if first_sample else self.momentum * self.avg + (1 - self.momentum) * val
        self.val = val
if __name__ == '__main__':

    ii = 0  # index of the next saved visualisation frame

    func = ODEFunc().to(device)
    optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
    end = time.time()

    # Smoothed per-iteration wall time and training loss.
    # NOTE(review): both meters are updated below but never reported.
    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        # Pass --method through. Before this fix the flag was parsed but never
        # used, so odeint always ran with its own default solver.
        # NOTE(review): newer torchdiffeq releases may not accept 'adams'
        # (they list e.g. 'explicit_adams' / 'implicit_adams'); confirm the
        # --method choices against the installed version.
        pred_y = odeint(func, batch_y0, batch_t, method=args.method)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % args.test_freq == 0:
            # Evaluate on the full trajectory from the true initial state.
            with torch.no_grad():
                pred_y = odeint(func, true_y0, t, method=args.method)
                loss = torch.mean(torch.abs(pred_y - true_y))
                print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
                visualize(true_y, pred_y, func, ii)
                ii += 1

        end = time.time()