Skip to content

Commit 560ecca

Browse files
Merge pull request #709 from rhayes777/features/nested
Features/nested
2 parents dd0a2ec + 32a09d4 commit 560ecca

File tree

9 files changed

+446
-69
lines changed

9 files changed

+446
-69
lines changed

autofit/graphical/factor_graphs/abstract.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
FlattenArrays,
1919
nested_filter,
2020
nested_update,
21+
nested_zip,
22+
nested_set,
23+
nested_map,
24+
nested_items,
2125
is_variable,
2226
Status,
2327
)
@@ -92,7 +96,7 @@ def resolve_variable_dict(
9296
def resolve_args(
9397
self, values: Dict[Variable, np.ndarray]
9498
) -> Tuple[np.ndarray, ...]:
95-
return (values[k] for k in self.args)
99+
return nested_update(self.args, values)
96100

97101
@cached_property
98102
def fixed_values(self) -> VariableData:
@@ -103,7 +107,7 @@ def variables(self) -> Set[Variable]:
103107
"""
104108
Dictionary mapping the names of variables to those variables
105109
"""
106-
return frozenset(self._kwargs.values())
110+
return frozenset(self.flat_args)
107111

108112
@property
109113
def free_variables(self) -> Set[Variable]:
@@ -121,12 +125,16 @@ def kwargs(self, kwargs):
121125
self._kwargs = kwargs
122126

123127
@property
124-
def args(self) -> Tuple[Variable, ...]:
128+
def args(self) -> Tuple[Any, ...]:
125129
return tuple(self.kwargs.values())
126130

127131
@property
128132
def arg_names(self) -> Tuple[str, ...]:
129133
return tuple(self.kwargs)
134+
135+
@property
136+
def flat_args(self) -> Tuple[Variable, ...]:
137+
return tuple(x for x, in nested_zip(self._kwargs))
130138

131139
@property
132140
def factor_out(self):
@@ -245,7 +253,7 @@ def _unique_representation(self):
245253
return (
246254
self._factor,
247255
self.arg_names,
248-
self.args,
256+
self.flat_args,
249257
self.deterministic_variables,
250258
)
251259

@@ -272,19 +280,19 @@ def _numerical_factor_jacobian(
272280
factor._factor(*args), jax.jacobian(factor._factor, range(len(args)))(*args)
273281
"""
274282
eps = eps or self.eps
275-
args = tuple(np.array(value, dtype=np.float64) for value in args)
283+
284+
args = nested_map(lambda _, val: np.array(val, dtype=np.float64), self.args, args)
276285

277286
raw_fval0 = self._factor_args(*args)
278287
fval0 = self._factor_value(raw_fval0).to_dict()
279288

280289
jac = {
281-
v0: tuple(
282-
np.empty_like(val, shape=np.shape(val) + np.shape(value))
283-
for value in args
284-
)
285-
for v0, val in fval0.items()
290+
v0: nested_map(
291+
lambda _, v: np.empty_like(val, shape=np.shape(val) + np.shape(v)),
292+
self.args, args
293+
) for v0, val in fval0.items()
286294
}
287-
for i, val in enumerate(args):
295+
for ks, _, val in nested_items(self.args, args):
288296
with np.nditer(val, op_flags=["readwrite"], flags=["multi_index"]) as it:
289297
for x_i in it:
290298
val[it.multi_index] += eps
@@ -293,7 +301,9 @@ def _numerical_factor_jacobian(
293301
x_i -= eps
294302
indexes = (Ellipsis,) + it.multi_index
295303
for v0, jac_v0v_i in jac_v1_i.items():
296-
jac[v0][i][indexes] = jac_v0v_i
304+
key_path = (v0, *ks, indexes)
305+
nested_set(jac, key_path, jac_v0v_i)
306+
# jac[v0][i][indexes] = jac_v0v_i
297307

298308
# This replicates the output of normal
299309
# jax.jacobian(self.factor, len(self.args))(*args)
@@ -304,7 +314,7 @@ def _numerical_factor_jacobian(
304314
def numerical_func_jacobian(
305315
self, values: VariableData, **kwargs
306316
) -> tuple:
307-
args = (values[k] for k in self.args)
317+
args = self.resolve_args(values)
308318
raw_fval, raw_jac = self._numerical_factor_jacobian(*args, **kwargs)
309319
fval = self._factor_value(raw_fval)
310320
jvp = self._jac_out_to_jvp(raw_jac, values=fval.to_dict().merge(values))

autofit/graphical/factor_graphs/factor.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from autofit.graphical.utils import (
1515
nested_filter,
16+
to_variabledata,
17+
nested_zip,
1618
is_variable,
1719
try_getitem,
1820
)
@@ -191,7 +193,6 @@ def __init__(
191193
**kwargs,
192194
)
193195

194-
# self.factor_out = factor_out
195196
self.eps = eps
196197
self._set_factor(factor)
197198
self._set_jacobians(
@@ -319,16 +320,16 @@ def _factor_value(self, raw_fval) -> FactorValue:
319320
where the values of the deterministic values are stored in a dict
320321
attribute `FactorValue.deterministic_values`
321322
"""
322-
det_values = VariableData(nested_filter(is_variable, self.factor_out, raw_fval))
323+
det_values = to_variabledata(self.factor_out, raw_fval)
323324
fval = det_values.pop(FactorValue, 0.0)
324325
return FactorValue(fval, det_values)
325326

326327
def __call__(self, values: VariableData) -> FactorValue:
327328
"""Calls the factor with the values specified by the dictionary of
328329
values passed, returns a FactorValue with the value returned by the
329330
factor, and any deterministic factors"""
330-
args = [values[v] for v in self.args]
331-
key = self._key("__call__", *args)
331+
args = self.resolve_args(values)
332+
key = self._key("__call__", *(val for _, val in nested_zip(self.args, args)))
332333

333334
if key not in self._cache:
334335
raw_fval = self._factor_args(*args)
@@ -351,7 +352,7 @@ def _vjp_func_jacobian(
351352
from autofit.graphical.factor_graphs.jacobians import (
352353
VectorJacobianProduct,
353354
)
354-
raw_fval, fvjp = self._factor_vjp(*(values[v] for v in self.args))
355+
raw_fval, fvjp = self._factor_vjp(*self.resolve_args(values))
355356
fval = self._factor_value(raw_fval)
356357

357358
fvjp_op = VectorJacobianProduct(
@@ -380,7 +381,7 @@ def _key(*args):
380381
def _jvp_func_jacobian(
381382
self, values: VariableData, **kwargs
382383
) -> Tuple[FactorValue, "JacobianVectorProduct"]:
383-
args = list(values[k] for k in self.args)
384+
args = self.resolve_args(values)
384385
key = self._key("_jvp_func_jacobian", *args)
385386

386387
if key not in self._cache:
@@ -402,7 +403,7 @@ def _unpack_jacobian_out(self, raw_jac: Any) -> Dict[Variable, VariableData]:
402403
jac = {}
403404
for v0, vjac in nested_filter(is_variable, self.factor_out, raw_jac):
404405
jac[v0] = VariableData()
405-
for v1, j in zip(self.args, vjac):
406+
for v1, j in nested_zip(self.args, vjac):
406407
jac[v0][v1] = j
407408

408409
return jac

autofit/graphical/factor_graphs/jacobians.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
nested_filter,
2121
nested_update,
2222
is_variable,
23+
to_variabledata,
2324
)
2425
from autofit.mapper.variable import (
2526
Variable,
@@ -114,9 +115,11 @@ def grad(self, values=None):
114115
if values:
115116
grad.update(values)
116117

117-
for v, g in self(grad).items():
118+
jac = self(grad)
119+
for v, g in jac.items():
118120
grad[v] = grad.get(v, 0) + g
119121

122+
grad.pop(FactorValue)
120123
return grad
121124

122125

@@ -138,13 +141,18 @@ def factor_out(self):
138141

139142
class VectorJacobianProduct(AbstractJacobian):
140143
def __init__(
141-
self, factor_out, vjp: Callable, *variables: Variable, out_shapes=None
144+
self, factor_out, vjp: Callable, *args: Variable, out_shapes=None
142145
):
143146
self.factor_out = factor_out
144147
self.vjp = vjp
145-
self._variables = variables
148+
self._args = args
149+
self._variables = tuple(v for v, in nested_filter(is_variable, args))
146150
self.out_shapes = out_shapes
147151

152+
@property
153+
def args(self):
154+
return self._args
155+
148156
@property
149157
def variables(self):
150158
return self._variables
@@ -172,7 +180,7 @@ def _get_cotangent(self, values):
172180
def __call__(self, values: Union[VariableData, FactorValue]) -> VariableData:
173181
v = self._get_cotangent(values)
174182
grads = self.vjp(v)
175-
return VariableData(zip(self.variables, grads))
183+
return to_variabledata(self.args, grads)
176184

177185
__rmul__ = __call__
178186

autofit/graphical/laplace/newton.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def diag_sr1_update(
7070
d = dzk.dot(dzk)
7171
if d > tol * dk.norm() ** 2 * zk.norm() ** 2:
7272
alpha = -zk.dot(dk) / d
73-
Bk = Bk.diagonalupdate(alpha * (zk ** 2))
73+
Bk = Bk.diagonalupdate((zk ** 2) * alpha)
7474

7575
state1.hessian = Bk
7676
return state1
@@ -93,7 +93,7 @@ def diag_sr1_update_(
9393
else:
9494
alpha[v] = 0.0
9595

96-
Bk = Bk.diagonalupdate(alpha * (zk ** 2))
96+
Bk = Bk.diagonalupdate((zk ** 2) * alpha)
9797

9898
state1.hessian = Bk
9999
return state1
@@ -184,7 +184,7 @@ def diag_quasi_deterministic_update(
184184
zk2 = zk ** 2
185185
zk4 = (zk2 ** 2).sum()
186186
alpha = (dk.dot(Bxk.dot(dk)) - zk.dot(Bzk.dot(zk))) / zk4
187-
state1.det_hessian = Bzk.diagonalupdate(float(alpha) * zk2)
187+
state1.det_hessian = Bzk.diagonalupdate(zk2 * alpha)
188188

189189
return state1
190190

0 commit comments

Comments
 (0)