Improve join_nonshared_inputs documentation (#6216)
wd60622 authored Nov 2, 2022
1 parent 23c4834 commit b0d1066
Showing 3 changed files with 147 additions and 36 deletions.
160 changes: 130 additions & 30 deletions pymc/aesaraf.py
@@ -57,7 +57,7 @@
)
from aesara.tensor.rewriting.basic import topo_constant_folding
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.sharedvar import SharedVariable
from aesara.tensor.sharedvar import SharedVariable, TensorSharedVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant, TensorVariable

@@ -535,55 +535,155 @@ def make_shared_replacements(point, vars, model):

def join_nonshared_inputs(
point: Dict[str, np.ndarray],
xs: List[TensorVariable],
vars: List[TensorVariable],
shared,
make_shared: bool = False,
):
outputs: List[TensorVariable],
inputs: List[TensorVariable],
shared_inputs: Optional[Dict[TensorVariable, TensorSharedVariable]] = None,
make_inputs_shared: bool = False,
) -> Tuple[List[TensorVariable], TensorVariable]:
"""
Takes a list of Aesara Variables and joins their non shared inputs into a single input.
Create new outputs and input TensorVariables where the non-shared inputs are joined
into a single raveled vector input.

Parameters
----------
point: a sample point
xs: list of Aesara tensors
vars: list of variables to join
point : dict of {str : array_like}
Dictionary that maps each input variable name to a numerical value. The values
are used to extract the shape of each input variable, to establish a correct
mapping between joined and original inputs. The shape of each variable is
assumed to be fixed.
outputs : list of TensorVariable
List of output TensorVariables whose non-shared inputs will be replaced
by a joined vector input.
inputs : list of TensorVariable
List of input TensorVariables which will be replaced by a joined vector input.
shared_inputs : dict of {TensorVariable : TensorSharedVariable}, optional
Dict mapping each input TensorVariable to the TensorSharedVariable that should
replace it in the subgraph.
make_inputs_shared : bool, default False
Whether to make the joined vector input a shared variable.

Returns
-------
tensors, inarray
tensors: list of same tensors but with inarray as input
inarray: vector of inputs
new_outputs : list of TensorVariable
List of new output TensorVariables, equivalent to `outputs`, but depending on
`joined_inputs` (and any shared replacements) as inputs.
joined_inputs : TensorVariable
Joined input vector TensorVariable for the `new_outputs`.

Examples
--------
Join the inputs of a simple Aesara graph.
.. code-block:: python
import aesara.tensor as at
import numpy as np
from pymc.aesaraf import join_nonshared_inputs
# Original non-shared inputs
x = at.scalar("x")
y = at.vector("y")
# Original output
out = x + y
print(out.eval({x: np.array(1), y: np.array([1, 2, 3])})) # [2, 3, 4]
# New output and inputs
[new_out], joined_inputs = join_nonshared_inputs(
point={ # Only shapes matter
"x": np.zeros(()),
"y": np.zeros(3),
},
outputs=[out],
inputs=[x, y],
)
print(new_out.eval({
joined_inputs: np.array([1, 1, 2, 3]),
})) # [2, 3, 4]
Join the input value variables of a model logp.
.. code-block:: python
import pymc as pm
with pm.Model() as model:
mu_pop = pm.Normal("mu_pop")
sigma_pop = pm.HalfNormal("sigma_pop")
mu = pm.Normal("mu", mu_pop, sigma_pop, shape=(3, ))
y = pm.Normal("y", mu, 1.0, observed=[0, 1, 2])
print(model.compile_logp()({
"mu_pop": 0,
"sigma_pop_log__": 1,
"mu": [0, 1, 2],
})) # -12.691227342634292
initial_point = model.initial_point()
inputs = model.value_vars
[logp], joined_inputs = join_nonshared_inputs(
point=initial_point,
outputs=[model.logp()],
inputs=inputs,
)
print(logp.eval({
joined_inputs: [0, 1, 0, 1, 2],
})) # -12.691227342634292
Same as above but with the `mu_pop` value variable being shared.
.. code-block:: python
from aesara import shared
mu_pop_input, *other_inputs = inputs
shared_mu_pop_input = shared(0.0)
[logp], other_joined_inputs = join_nonshared_inputs(
point=initial_point,
outputs=[model.logp()],
inputs=other_inputs,
shared_inputs={
mu_pop_input: shared_mu_pop_input
},
)
print(logp.eval({
other_joined_inputs: [1, 0, 1, 2],
})) # -12.691227342634292
"""
if not vars:
raise ValueError("Empty list of variables.")
if not inputs:
raise ValueError("Empty list of input variables.")

joined = at.concatenate([var.ravel() for var in vars])
raveled_inputs = at.concatenate([var.ravel() for var in inputs])

if not make_shared:
tensor_type = joined.type
inarray = tensor_type("inarray")
if not make_inputs_shared:
tensor_type = raveled_inputs.type
joined_inputs = tensor_type("joined_inputs")
else:
if point is None:
raise ValueError("A point is required when `make_shared` is True")
joined_values = np.concatenate([point[var.name].ravel() for var in vars])
inarray = aesara.shared(joined_values, "inarray")
joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
joined_inputs = aesara.shared(joined_values, "joined_inputs")

if aesara.config.compute_test_value != "off":
inarray.tag.test_value = joined.tag.test_value
joined_inputs.tag.test_value = raveled_inputs.tag.test_value

replace = {}
replace: Dict[TensorVariable, TensorVariable] = {}
last_idx = 0
for var in vars:
for var in inputs:
shape = point[var.name].shape
arr_len = np.prod(shape, dtype=int)
replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
last_idx += arr_len

replace.update(shared)
if shared_inputs is not None:
replace.update(shared_inputs)

xs_special = [aesara.clone_replace(x, replace, rebuild_strict=False) for x in xs]
return xs_special, inarray
new_outputs = [
aesara.clone_replace(output, replace, rebuild_strict=False) for output in outputs
]
return new_outputs, joined_inputs


class PointFunc:
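Usage note (not part of the diff): the joined graph returned by the new keyword-based API can be compiled into an ordinary callable. A minimal sketch, reusing the `x`/`y` example from the docstring above together with `aesara.function`:

.. code-block:: python

import aesara
import aesara.tensor as at
import numpy as np

from pymc.aesaraf import join_nonshared_inputs

# Original non-shared inputs and output, as in the docstring example above.
x = at.scalar("x")
y = at.vector("y")
out = x + y

# Replace the separate inputs by one raveled vector input.
[new_out], joined_inputs = join_nonshared_inputs(
    point={"x": np.zeros(()), "y": np.zeros(3)},  # only the shapes are used
    outputs=[out],
    inputs=[x, y],
)

# Compile a function of the single joined input.
f = aesara.function([joined_inputs], new_out)
print(f(np.array([1.0, 1.0, 2.0, 3.0])))  # [2. 3. 4.]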
10 changes: 6 additions & 4 deletions pymc/smc/kernels.py
@@ -579,11 +579,11 @@ def _logp_forw(point, out_vars, in_vars, shared):
Parameters
----------
out_vars: List
out_vars : list
containing :class:`pymc.Distribution` for the output variables
in_vars: List
in_vars : list
containing :class:`pymc.Distribution` for the input variables
shared: List
shared : list
containing :class:`aesara.tensor.Tensor` for the shared data the graph depends on
"""

@@ -602,7 +602,9 @@ def _logp_forw(point, out_vars, in_vars, shared):
out_vars = clone_replace(out_vars, replace_int_input, rebuild_strict=False)
in_vars = new_in_vars

out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared)
out_list, inarray0 = join_nonshared_inputs(
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
)
f = compile_pymc([inarray0], out_list[0])
f.trust_input = True
return f
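For context (not part of the diff): `_logp_forw` uses the renamed keyword arguments to build a compiled log-probability function of a single raveled input. A rough sketch of the same pattern on a small throwaway model, using only APIs already shown in this commit (the integer-variable cloning that `_logp_forw` also performs is omitted):

.. code-block:: python

import numpy as np
import pymc as pm

from pymc.aesaraf import compile_pymc, join_nonshared_inputs

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0, shape=(3,))
    pm.Normal("y", mu, 1.0, observed=[0.0, 1.0, 2.0])

# Joined logp graph over all value variables, compiled the way _logp_forw does.
[logp], inarray0 = join_nonshared_inputs(
    point=model.initial_point(),
    outputs=[model.logp()],
    inputs=model.value_vars,
)
f = compile_pymc([inarray0], logp)
print(f(np.zeros(3)))  # model log-probability at mu = [0, 0, 0]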
13 changes: 11 additions & 2 deletions pymc/step_methods/metropolis.py
@@ -13,11 +13,13 @@
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple

import aesara
import numpy as np
import numpy.random as nr
import scipy.linalg
import scipy.special

from aesara import tensor as at
from aesara.graph.fg import MissingInputError
from aesara.tensor.random.basic import BernoulliRV, CategoricalRV

@@ -1052,8 +1054,15 @@ def sample_except(limit, excluded):
return candidate


def delta_logp(point, logp, vars, shared):
[logp0], inarray0 = join_nonshared_inputs(point, [logp], vars, shared)
def delta_logp(
point: Dict[str, np.ndarray],
logp: at.TensorVariable,
vars: List[at.TensorVariable],
shared: Dict[at.TensorVariable, at.sharedvar.TensorSharedVariable],
) -> aesara.compile.Function:
[logp0], inarray0 = join_nonshared_inputs(
point=point, outputs=[logp], inputs=vars, shared_inputs=shared
)

tensor_type = inarray0.type
inarray1 = tensor_type("inarray1")
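The rest of `delta_logp` is truncated above. For orientation only, here is a hedged sketch of the overall idea — a compiled function returning logp(q1) - logp(q0) for two raveled points — written with `clone_replace` rather than claiming the exact internals; `delta_logp_sketch` is a hypothetical name, not the function in this commit:

.. code-block:: python

import aesara

from pymc.aesaraf import compile_pymc, join_nonshared_inputs


def delta_logp_sketch(point, logp, vars, shared):
    # Join the non-shared value variables into one raveled input, as in the diff above.
    [logp0], inarray0 = join_nonshared_inputs(
        point=point, outputs=[logp], inputs=vars, shared_inputs=shared
    )
    # A second input vector of the same type, representing the proposed point.
    inarray1 = inarray0.type("inarray1")
    # Evaluate the same logp graph at the proposed point.
    logp1 = aesara.clone_replace(logp0, {inarray0: inarray1})
    # f(q_new, q_old) -> logp(q_new) - logp(q_old)
    f = compile_pymc([inarray1, inarray0], logp1 - logp0)
    f.trust_input = True
    return f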
