kalman_filter.py

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example: Kalman Filter
======================

This example builds a Gaussian hidden Markov model with funsor, exactly
marginalizes out the latent states via ``reduce(ops.logaddexp, ...)``, and
fits the transition and emission noise parameters by maximizing the marginal
log likelihood with Adam.
"""

import argparse

import torch

import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import reinterpret
from funsor.optimizer import apply_optimizer


def main(args):
    funsor.set_backend("torch")

    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.0)

        x_curr = funsor.Tensor(torch.tensor(0.0))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable("x_{}".format(t), funsor.Real)
            log_prob += dist.Normal(1 + x_prev / 2.0, trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            # Build the log joint lazily, let funsor's optimizer choose an
            # efficient contraction order, then evaluate it eagerly.
            with funsor.interpretations.lazy:
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, "free variables remain"
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print("step {} loss = {}".format(step, loss.item()))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Kalman filter example")
    parser.add_argument("-t", "--time-steps", default=10, type=int)
    parser.add_argument("-n", "--train-steps", default=101, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
    parser.add_argument("--lazy", action="store_true")
    parser.add_argument("--filter", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    main(args)
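
# A minimal usage sketch, assuming the file is saved as kalman_filter.py.
# The flag names come from the argument parser above; the commands are
# illustrative, not transcripts of a real run:
#
#     python kalman_filter.py --time-steps 20 --verbose
#     python kalman_filter.py --lazy -v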