forked from patrick-kidger/diffrax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
path.py
248 lines (218 loc) · 8.85 KB
/
path.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import math
from typing import cast, Optional, Union
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax.internal as lxi
from jaxtyping import Array, PRNGKeyArray, PyTree
from lineax.internal import complex_to_real_dtype
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
DavieFosterWeakSpaceSpaceLevyArea,
DavieWeakSpaceSpaceLevyArea,
levy_tree_transpose,
RealScalarLike,
SpaceTimeLevyArea,
SpaceTimeTimeLevyArea,
)
from .._misc import (
force_bitcast_convert_type,
is_tuple_of_ints,
split_by_tree,
)
from .base import AbstractBrownianPath
# The Lévy-area types accepted by `UnsafeBrownianPath` for its `levy_area`
# argument. Subscripting `Union` with an explicit tuple is the same as listing
# the members directly.
_Levy_Areas = Union[
    (
        BrownianIncrement,
        SpaceTimeLevyArea,
        SpaceTimeTimeLevyArea,
        DavieWeakSpaceSpaceLevyArea,
        DavieFosterWeakSpaceSpaceLevyArea,
    )
]
class UnsafeBrownianPath(AbstractBrownianPath):
    """Brownian simulation that is only suitable for certain cases.

    This is a very quick way to simulate Brownian motion, but can only be used when all
    of the following are true:

    1. You are using a fixed step size controller. (Not an adaptive one.)

    2. You do not need to backpropagate through the differential equation.

    3. You do not need deterministic solutions with respect to `key`. (This
       implementation will produce different results based on fluctuations in
       floating-point arithmetic.)

    Internally this operates by just sampling a fresh normal random variable over every
    interval, ignoring the correlation between samples exhibited in true Brownian
    motion. Hence the restrictions above. (They describe the general case for which the
    correlation structure isn't needed.)

    !!! info "Lévy Area"

        Can be initialised with `levy_area` set to any of:

        - `diffrax.BrownianIncrement`: only the increment `W` is sampled.
        - `diffrax.SpaceTimeLevyArea`: additionally samples the space-time Lévy area
          `H`, with variance `dt / 12`.
        - `diffrax.SpaceTimeTimeLevyArea`: additionally samples `H` and the
          space-time-time Lévy area `K`, with variance `dt / 720`.
        - `DavieWeakSpaceSpaceLevyArea` / `DavieFosterWeakSpaceSpaceLevyArea`:
          additionally samples `H` and a weak approximation of the space-space Lévy
          area over the last axis, `A = H ⊗ W - W ⊗ H + λ`, where `λ` is an
          antisymmetric Gaussian matrix. Davie's scheme uses
          `Var(λ_ij) = dt² / 12`; Foster's refinement uses
          `Var(λ_ij) = dt² / 20 + dt (H_i² + H_j²) / 5`, which has the same mean
          over `H`. For leaves with at most one axis, `A` is returned as zeros.

        These are additional sources of randomness required for certain stochastic
        Runge--Kutta solvers; see [`diffrax.AbstractSRK`][] for more information.

        An error will be thrown during tracing if Lévy area is required but is not
        available.

        The choice here will impact the Brownian path, so even with the same key, the
        trajectory will be different depending on the value of `levy_area`.
    """

    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
    levy_area: type[_Levy_Areas] = eqx.field(static=True)
    key: PRNGKeyArray

    def __init__(
        self,
        shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
        key: PRNGKeyArray,
        levy_area: type[_Levy_Areas] = BrownianIncrement,
    ):
        # A bare tuple of ints is shorthand for a single array of the default
        # floating-point dtype.
        self.shape = (
            jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
            if is_tuple_of_ints(shape)
            else shape
        )
        self.key = key
        self.levy_area = levy_area
        if any(
            not jnp.issubdtype(x.dtype, jnp.inexact)
            for x in jtu.tree_leaves(self.shape)
        ):
            raise ValueError("UnsafeBrownianPath dtypes all have to be floating-point.")

    @property
    def t0(self):
        """Start of the domain: unbounded below."""
        return -jnp.inf

    @property
    def t1(self):
        """End of the domain: unbounded above."""
        return jnp.inf

    @eqx.filter_jit
    def evaluate(
        self,
        t0: RealScalarLike,
        t1: Optional[RealScalarLike] = None,
        left: bool = True,
        use_levy: bool = False,
        k: Optional[PRNGKeyArray] = None,
    ) -> Union[PyTree[Array], AbstractBrownianIncrement]:
        """Sample the Brownian increment over `[t0, t1]`.

        **Arguments:**

        - `t0`: Start of the interval. If `t1` is `None`, the interval is instead
            `[0, t0]`.
        - `t1`: End of the interval, or `None`.
        - `left`: Ignored.
        - `use_levy`: If `True`, return an instance of `self.levy_area`; otherwise
            return only the increment `W`, a PyTree with the structure of
            `self.shape`.
        - `k`: Optional PRNG key. If given it is used directly. Otherwise the key is
            derived by folding `t0` and `t1` (bit-cast to int32) into `self.key`, so
            the same interval endpoints always produce the same sample.

        **Returns:**

        The increment `W`, or the `self.levy_area` object if `use_levy=True`.
        """
        del left
        if t1 is None:
            dtype = jnp.result_type(t0)
            t1 = t0
            t0 = jnp.array(0, dtype)
        else:
            with jax.numpy_dtype_promotion("standard"):
                dtype = jnp.result_type(t0, t1)
            t0 = jnp.astype(t0, dtype)
            t1 = jnp.astype(t1, dtype)
        # This path is not meant to be differentiated through (see class docs).
        t0 = eqxi.nondifferentiable(t0, name="t0")
        t1 = eqxi.nondifferentiable(t1, name="t1")
        t1 = cast(RealScalarLike, t1)
        if k is None:
            t0_ = force_bitcast_convert_type(t0, jnp.int32)
            t1_ = force_bitcast_convert_type(t1, jnp.int32)
            key = jr.fold_in(self.key, t0_)
            key = jr.fold_in(key, t1_)
        else:
            key = k
        # One independent key per leaf of `self.shape`.
        key = split_by_tree(key, self.shape)
        out = jtu.tree_map(
            lambda key, shape: self._evaluate_leaf(
                t0, t1, key, shape, self.levy_area, use_levy
            ),
            key,
            self.shape,
        )
        if use_levy:
            # tree_map produced a PyTree of Lévy-area objects; flip it into a single
            # Lévy-area object whose fields are PyTrees.
            out = levy_tree_transpose(self.shape, out)
            assert isinstance(out, self.levy_area)
        return out

    @staticmethod
    def _weak_space_space_area(w, hh, key_b, b_std):
        """Return the weak space-space Lévy area `H ⊗ W - W ⊗ H + λ`.

        The outer products are taken over the last axis of `w`/`hh`, giving a result
        of shape `w.shape + w.shape[-1:]`. `λ` is an antisymmetric Gaussian matrix
        whose `(i, j)` entry (`i != j`) has standard deviation `b_std[..., i, j]`;
        `b_std` must be a scalar or symmetric in its last two axes.
        """
        b = jr.normal(key_b, w.shape + w.shape[-1:], w.dtype) * b_std
        # b_ij and b_ji are independent with equal std, so (b_ij - b_ji) / sqrt(2)
        # has std b_std_ij exactly. Without the rescaling the off-diagonal variance
        # would be doubled.
        lam = (b - jnp.swapaxes(b, -1, -2)) / math.sqrt(2)
        area = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims(
            w, -1
        ) * jnp.expand_dims(hh, -2)
        return area + lam

    @staticmethod
    def _evaluate_leaf(
        t0: RealScalarLike,
        t1: RealScalarLike,
        key,
        shape: jax.ShapeDtypeStruct,
        levy_area: type[_Levy_Areas],
        use_levy: bool,
    ):
        """Sample one leaf of the increment (and optionally Lévy area) over `[t0, t1]`.

        **Arguments:**

        - `t0`, `t1`: Interval endpoints; `W` has variance `t1 - t0`.
        - `key`: PRNG key for this leaf only.
        - `shape`: Shape and dtype of this leaf.
        - `levy_area`: Which Lévy-area type to sample.
        - `use_levy`: If `True` return the `levy_area` instance, otherwise only `W`.

        **Raises:**

        `ValueError` if `levy_area` is not one of the supported types.
        """
        w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
        dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))

        if levy_area is SpaceTimeTimeLevyArea:
            key_w, key_hh, key_kk = jr.split(key, 3)
            w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
            hh_std = w_std / math.sqrt(12)
            hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
            kk_std = w_std / math.sqrt(720)
            kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std
            levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)
        elif levy_area is DavieWeakSpaceSpaceLevyArea:
            key_w, key_hh, key_b = jr.split(key, 3)
            w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
            hh_std = w_std / math.sqrt(12)
            hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
            if w.ndim <= 1:
                # NOTE(review): a 1-D leaf of shape (d,) gets a zero `A` of shape
                # (d,), not a (d, d) area matrix; confirm this is intended.
                a = jnp.zeros_like(w, dtype=shape.dtype)
            else:
                # Davie: Var(λ_ij) = dt^2 / 12. Together with
                # Var(H_i W_j - W_i H_j) = dt^2 / 6 this gives Var(A_ij) = dt^2 / 4,
                # the variance of true Lévy area.
                b_std = (dt / jnp.sqrt(12)).astype(shape.dtype)
                a = UnsafeBrownianPath._weak_space_space_area(w, hh, key_b, b_std)
            levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
        elif levy_area is DavieFosterWeakSpaceSpaceLevyArea:
            key_w, key_hh, key_b = jr.split(key, 3)
            w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
            hh_std = w_std / math.sqrt(12)
            hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
            if w.ndim <= 1:
                # NOTE(review): see the matching note in the Davie branch.
                a = jnp.zeros_like(w, dtype=shape.dtype)
            else:
                # Foster: Var(λ_ij) = dt^2/20 + dt/5 * (H_i^2 + H_j^2)
                #                   = dt/5 * (dt/4 + H_i^2 + H_j^2).
                # Since E[H_i^2] = dt/12, the mean over H is dt^2/12, as for Davie.
                fifth_dt = (0.2 * dt).astype(shape.dtype)
                quarter_dt = (0.25 * dt).astype(shape.dtype)
                hh_squared = hh**2
                b_std = jnp.sqrt(
                    fifth_dt
                    * (
                        quarter_dt
                        + jnp.expand_dims(hh_squared, -1)
                        + jnp.expand_dims(hh_squared, -2)
                    )
                ).astype(shape.dtype)
                a = UnsafeBrownianPath._weak_space_space_area(w, hh, key_b, b_std)
            levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
        elif levy_area is SpaceTimeLevyArea:
            key_w, key_hh = jr.split(key, 2)
            w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
            hh_std = w_std / math.sqrt(12)
            hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
            levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh)
        elif levy_area is BrownianIncrement:
            w = jr.normal(key, shape.shape, shape.dtype) * w_std
            levy_val = BrownianIncrement(dt=dt, W=w)
        else:
            raise ValueError(
                f"Unsupported `levy_area` for UnsafeBrownianPath: {levy_area!r}."
            )

        if use_levy:
            return levy_val
        return w
# `__init__`'s documentation is attached after the class body, matching the
# convention used for the class itself.
UnsafeBrownianPath.__init__.__doc__ = """
**Arguments:**

- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape,
    dtype, and PyTree structure of the output. For simplicity, `shape` can also just
    be a tuple of integers, describing the shape of a single JAX array. In that case
    the dtype is chosen to be the default floating-point dtype. All dtypes must be
    floating-point (real or complex); otherwise a `ValueError` is raised.
- `key`: A random key.
- `levy_area`: Which kind of Lévy area to additionally generate. One of
    `diffrax.BrownianIncrement` (the default; no Lévy area),
    `diffrax.SpaceTimeLevyArea`, `diffrax.SpaceTimeTimeLevyArea`,
    `DavieWeakSpaceSpaceLevyArea` or `DavieFosterWeakSpaceSpaceLevyArea`. Lévy area
    is required by some SDE solvers.
"""