-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathfit-hmc-ad.dx
95 lines (75 loc) · 2.44 KB
/
fit-hmc-ad.dx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
'# HMC using Dex with auto-differentiated gradients
-- load some generic utility functions
import djwutils
'## now read and process the data
-- Read the contents of the data file into `dat`.
dat = unsafe_io \. read_file "../pima.data"
-- Split `dat` into rows of string fields, with ' ' passed as the separator.
-- NOTE(review): parse_tsv is presumably provided by djwutils; confirm its
-- exact splitting rules (e.g. handling of repeated separators).
AsList(_, tab) = parse_tsv ' ' dat
-- Prepend the string "1.0" to every row (presumably the intercept column
-- of the design matrix; confirm).
atab = map (\l. cons "1.0" l) tab
-- Convert each row into a fixed-size table of 9 strings.
att = map (\r. list2tab r :: (Fin 9)=>String) atab
-- First 8 fields of each row: the covariates, still as strings.
xStr = map (\r. slice r 0 (Fin 8)) att
-- Parse every covariate field; each parse yields a Maybe Float.
xmb = map (\r. map parseString r) xStr :: _=>(Fin 8)=>(Maybe Float)
-- Unwrap the parsed values.
-- NOTE(review): from_just presumably fails on Nothing, i.e. on any field
-- that did not parse as a Float; confirm.
x = map (\r. map from_just r) xmb :: _=>(Fin 8)=>Float
-- Field at offset 8 of each row, kept as a length-1 table.
yStrM = map (\r. slice r 8 (Fin 1)) att
-- Transpose and take row 0 to obtain a flat table of response strings.
yStr = (transpose yStrM)[0@_]
-- Encode the response as 1.0 for "Yes" and 0.0 for anything else.
y = map (\s. select (s == "Yes") 1.0 0.0) yStr
-- Display the covariate matrix and the response vector.
x
y
'## now set up for MCMC
-- Log-likelihood of the coefficient vector `b` given the global data
-- `x` (covariates) and `y` (0/1 responses):
--   -sum_i log(1 + exp((1 - 2*y_i) * (x b)_i))
def ll(b: (Fin 8)=>Float) -> Float =
  eta = x **. b
  terms = for i.
    z = (1 - 2*y[i]) * eta[i]
    log ((exp z) + 1)
  neg $ sum terms
-- Prior standard deviations, one per coefficient.
pscale = [10.0, 1, 1, 1, 1, 1, 1, 1]
-- Elementwise reciprocals of `pscale`.
prscale = for i. 1.0 / pscale[i]
-- Log prior density of `b`, up to an additive constant: independent
-- zero-mean normal terms with standard deviations `pscale`.
def lprior(b: (Fin 8)=>Float) -> Float =
  terms = for i.
    sd = pscale[i]
    z = b[i] * prscale[i]
    (log sd) + (0.5 * (z * z))
  neg $ sum terms
-- Unnormalised log posterior of `b`: log-likelihood plus log prior.
def lpost(b: (Fin 8)=>Float) -> Float =
  logLik = ll b
  logPrior = lprior b
  logLik + logPrior
-- Root PRNG key built from the fixed seed 42; it is handed to markov_chain
-- at the end of the file.
k = new_key 42
-- Metropolis kernel with a deterministic proposal.
--   lpost: log target density on the state type `s`
--   prop:  deterministic proposal map
--   x0:    current state
--   k:     PRNG key used for the accept/reject draw
-- Returns the proposed state when the log of a `rand` draw is below the
-- log density difference, and `x0` otherwise.
def mdKernel(lpost: (s) -> Float, prop: (s) -> s,
    x0: s, k: Key) -> s given (s) =
  cand = prop x0
  logRatio = (lpost cand) - (lpost x0)
  logU = log (rand k)
  select (logU < logRatio) cand x0
-- One Hamiltonian Monte Carlo transition targeting the log density `lpi`.
--   lpi: log target density; its gradient is obtained with `grad`
--   dmm: per-coordinate scale of the momentum (momentum is drawn as
--        sqrt(dmm) * randn_vec, and 1/dmm weights the kinetic term)
--   eps: leapfrog step size
--   l:   repeat count passed to apply_n for the inner leapfrog updates
--   q0:  current position
--   k:   PRNG key, split into one key for the momentum draw and one for
--        the accept/reject step
-- Returns the position component of the state chosen by mdKernel.
def hmcKernel(lpi: (Fin n=>Float) -> Float,
    dmm: (Fin n)=>Float, eps: Float, l: Nat,
    q0: (Fin n)=>Float, k: Key) -> (Fin n)=>Float given (n) =
  massSd = sqrt dmm
  invMass = for i. 1.0 / dmm[i]
  gradLpi = \pos. grad lpi pos
  -- Leapfrog trajectory from a (position, momentum) pair: half momentum
  -- step, first position step, `l` repeated full updates, a closing half
  -- momentum step, then the momentum is negated.
  def leapf(qp: ((Fin n=>Float), (Fin n)=>Float)) ->
      ((Fin n=>Float), (Fin n)=>Float) =
    (qStart, pStart) = qp
    pHalf = pStart + (eps/2) .* (gradLpi qStart)
    qFirst = qStart + eps .* (pHalf*invMass)
    (qEnd, pLast) = apply_n l (qFirst, pHalf) \state.
      (qCur, pCur) = state
      pNext = pCur + eps .* (gradLpi qCur)
      qNext = qCur + eps .* (pNext*invMass)
      (qNext, pNext)
    pEnd = pLast + (eps/2) .* (gradLpi qEnd)
    (qEnd, -pEnd)
  -- Joint log density of position and momentum.
  def logJoint(qp: ((Fin n=>Float), (Fin n)=>Float)) -> Float =
    (qPos, pMom) = qp
    kinetic = 0.5*(sum (pMom*pMom*invMass))
    (lpi qPos) - kinetic
  [kMom, kAcc] = split_key k
  noise = randn_vec kMom
  p0 = massSd * noise
  (qNew, pNew) = mdKernel logJoint leapf (q0, p0) kAcc
  qNew
-- Per-coordinate values whose reciprocals are passed to hmcKernel as `dmm`.
pre = [100.0, 1, 1, 1, 1, 1, 25, 1]
-- One HMC transition for `lpost` with step size 1.0e-3 and l = 49.
kern = \s k. hmcKernel lpost (map (\x. 1.0/x) pre) 1.0e-3 49 s k
-- Starting point of the chain.
init = [-9.0,0,0,0,0,0,0,0]
-- Run the chain from `init` using key `k`, with 20 and 10000 passed to
-- step_n and markov_chain respectively.
-- NOTE(review): presumably 10000 stored iterations, each made of 20
-- applications of `kern` (thinning); confirm against djwutils.
mat = markov_chain init (\s k. step_n 20 kern s k) 10000 k
-- Summary statistics of the stored samples.
mv = meanAndCovariance mat
fst mv -- mean
snd mv -- (co)variance matrix
-- Write the stored samples to disk in TSV form.
unsafe_io \. write_file "fit-hmc-ad.tsv" (to_tsv mat)
-- eof