Improve join_nonshared_inputs documentation #6216

Merged: 22 commits from wd60622:aesaraf_join_shared_input_doc into pymc-devs:main on Nov 2, 2022

Commits
ca8862a  adding type hints, using keyword args, and start of example based on … (wd60622, Oct 14, 2022)
d70bb6f  adding second example and some more to the docs (wd60622, Oct 19, 2022)
564ab09  add the suggestions from PR (wd60622, Oct 20, 2022)
ada261a  Merge branch 'pymc-devs:main' into aesaraf_join_shared_input_doc (wd60622, Oct 20, 2022)
a526a03  first example did not render (wd60622, Oct 20, 2022)
f9a165a  fix typo (wd60622, Oct 21, 2022)
c4be323  switch to numpy docstyle. comment in example. fix return (wd60622, Oct 26, 2022)
117b929  Merge branch 'pymc-devs:main' into aesaraf_join_shared_input_doc (wd60622, Oct 26, 2022)
b3ca27d  forgotten dict (wd60622, Oct 27, 2022)
7e19350  numpy docstring standards (wd60622, Oct 27, 2022)
fe95d58  Merge branch 'pymc-devs:main' into aesaraf_join_shared_input_doc (wd60622, Oct 27, 2022)
2923927  adding type hints, using keyword args, and start of example based on … (wd60622, Oct 14, 2022)
3a8b3fb  adding second example and some more to the docs (wd60622, Oct 19, 2022)
70601f5  add the suggestions from PR (wd60622, Oct 20, 2022)
acc4f6d  first example did not render (wd60622, Oct 20, 2022)
fb0a141  fix typo (wd60622, Oct 21, 2022)
4f9d687  switch to numpy docstyle. comment in example. fix return (wd60622, Oct 26, 2022)
a1dc3c4  forgotten dict (wd60622, Oct 27, 2022)
7065982  numpy docstring standards (wd60622, Oct 27, 2022)
f37dbaa  Merge branch 'aesaraf_join_shared_input_doc' of https://github.com/wd… (wd60622, Nov 1, 2022)
a50be1f  switching away from Point (wd60622, Nov 1, 2022)
8360f86  Merge branch 'pymc-devs:main' into aesaraf_join_shared_input_doc (wd60622, Nov 1, 2022)
160 changes: 130 additions & 30 deletions pymc/aesaraf.py
@@ -57,7 +57,7 @@
 )
 from aesara.tensor.rewriting.basic import topo_constant_folding
 from aesara.tensor.rewriting.shape import ShapeFeature
-from aesara.tensor.sharedvar import SharedVariable
+from aesara.tensor.sharedvar import SharedVariable, TensorSharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable

@@ -535,55 +535,155 @@ def make_shared_replacements(point, vars, model):

 def join_nonshared_inputs(
     point: Dict[str, np.ndarray],
-    xs: List[TensorVariable],
-    vars: List[TensorVariable],
-    shared,
-    make_shared: bool = False,
-):
+    outputs: List[TensorVariable],
+    inputs: List[TensorVariable],
+    shared_inputs: Optional[Dict[TensorVariable, TensorSharedVariable]] = None,
+    make_inputs_shared: bool = False,
+) -> Tuple[List[TensorVariable], TensorVariable]:
     """
-    Takes a list of Aesara Variables and joins their non shared inputs into a single input.
+    Create new outputs and input TensorVariables where the non-shared inputs are joined
+    in a single raveled vector input.

     Parameters
     ----------
-    point: a sample point
-    xs: list of Aesara tensors
-    vars: list of variables to join
+    point : dict of {str : array_like}
+        Dictionary that maps each input variable name to a numerical variable. The values
+        are used to extract the shape of each input variable to establish a correct
+        mapping between joined and original inputs. The shape of each variable is
+        assumed to be fixed.
+    outputs : list of TensorVariable
+        List of output TensorVariables whose non-shared inputs will be replaced
+        by a joined vector input.
+    inputs : list of TensorVariable
+        List of input TensorVariables which will be replaced by a joined vector input.
+    shared_inputs : dict of {TensorVariable : TensorSharedVariable}, optional
+        Dict mapping each input TensorVariable to the TensorSharedVariable that
+        replaces it in the subgraph.
+    make_inputs_shared : bool, default False
+        Whether to make the joined vector input a shared variable.

     Returns
     -------
-    tensors, inarray
-        tensors: list of same tensors but with inarray as input
-        inarray: vector of inputs
+    new_outputs : list of TensorVariable
+        List of new output TensorVariables, rebuilt from `outputs`, that depend on
+        `joined_inputs` and the new shared variables as inputs.
+    joined_inputs : TensorVariable
+        Joined input vector TensorVariable for the `new_outputs`.

+    Examples
+    --------
+    Join the inputs of a simple Aesara graph.
+
+    .. code-block:: python
+
+        import aesara.tensor as at
+        import numpy as np
+
+        from pymc.aesaraf import join_nonshared_inputs
+
+        # Original non-shared inputs
+        x = at.scalar("x")
+        y = at.vector("y")
+        # Original output
+        out = x + y
+        print(out.eval({x: np.array(1), y: np.array([1, 2, 3])}))  # [2, 3, 4]
+
+        # New output and inputs
+        [new_out], joined_inputs = join_nonshared_inputs(
+            point={  # Only shapes matter
+                "x": np.zeros(()),
+                "y": np.zeros(3),
+            },
+            outputs=[out],
+            inputs=[x, y],
+        )
+        print(new_out.eval({
+            joined_inputs: np.array([1, 1, 2, 3]),
+        }))  # [2, 3, 4]
+
+    Join the input value variables of a model logp.
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            mu_pop = pm.Normal("mu_pop")
+            sigma_pop = pm.HalfNormal("sigma_pop")
+            mu = pm.Normal("mu", mu_pop, sigma_pop, shape=(3,))
+
+            y = pm.Normal("y", mu, 1.0, observed=[0, 1, 2])
+
+        print(model.compile_logp()({
+            "mu_pop": 0,
+            "sigma_pop_log__": 1,
+            "mu": [0, 1, 2],
+        }))  # -12.691227342634292
+
+        initial_point = model.initial_point()
+        inputs = model.value_vars
+
+        [logp], joined_inputs = join_nonshared_inputs(
+            point=initial_point,
+            outputs=[model.logp()],
+            inputs=inputs,
+        )
+
+        print(logp.eval({
+            joined_inputs: [0, 1, 0, 1, 2],
+        }))  # -12.691227342634292
+
+    Same as above but with the `mu_pop` value variable being shared.
+
+    .. code-block:: python
+
+        from aesara import shared
+
+        mu_pop_input, *other_inputs = inputs
+        shared_mu_pop_input = shared(0.0)
+
+        [logp], other_joined_inputs = join_nonshared_inputs(
+            point=initial_point,
+            outputs=[model.logp()],
+            inputs=other_inputs,
+            shared_inputs={
+                mu_pop_input: shared_mu_pop_input,
+            },
+        )
+
+        print(logp.eval({
+            other_joined_inputs: [1, 0, 1, 2],
+        }))  # -12.691227342634292
"""
if not vars:
raise ValueError("Empty list of variables.")
if not inputs:
raise ValueError("Empty list of input variables.")

joined = at.concatenate([var.ravel() for var in vars])
raveled_inputs = at.concatenate([var.ravel() for var in inputs])

if not make_shared:
tensor_type = joined.type
inarray = tensor_type("inarray")
if not make_inputs_shared:
tensor_type = raveled_inputs.type
joined_inputs = tensor_type("joined_inputs")
else:
if point is None:
raise ValueError("A point is required when `make_shared` is True")
-        joined_values = np.concatenate([point[var.name].ravel() for var in vars])
-        inarray = aesara.shared(joined_values, "inarray")
+        joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
+        joined_inputs = aesara.shared(joined_values, "joined_inputs")

     if aesara.config.compute_test_value != "off":
-        inarray.tag.test_value = joined.tag.test_value
+        joined_inputs.tag.test_value = raveled_inputs.tag.test_value

-    replace = {}
+    replace: Dict[TensorVariable, TensorVariable] = {}
     last_idx = 0
-    for var in vars:
+    for var in inputs:
         shape = point[var.name].shape
         arr_len = np.prod(shape, dtype=int)
-        replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
+        replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
         last_idx += arr_len

-    replace.update(shared)
+    if shared_inputs is not None:
+        replace.update(shared_inputs)

-    xs_special = [aesara.clone_replace(x, replace, rebuild_strict=False) for x in xs]
-    return xs_special, inarray
+    new_outputs = [
+        aesara.clone_replace(output, replace, rebuild_strict=False) for output in outputs
+    ]
+    return new_outputs, joined_inputs


 class PointFunc:
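Side note for readers of this diff: the move from positional `xs`/`vars`/`shared` to keyword `outputs`/`inputs`/`shared_inputs` changes how call sites read. A minimal before/after sketch, with hypothetical variable names (`graph_outputs`, `value_vars`, `shared` are placeholders, not lines from this PR):

    # Before this PR: positional arguments with the old parameter names.
    #   tensors, inarray = join_nonshared_inputs(point, graph_outputs, value_vars, shared)
    # After this PR: explicit keyword arguments with the new parameter names.
    new_outputs, joined_inputs = join_nonshared_inputs(
        point=point,            # maps value-variable names to arrays; only shapes are used
        outputs=graph_outputs,  # graphs to rebuild on the joined vector input
        inputs=value_vars,      # non-shared inputs to ravel and concatenate
        shared_inputs=shared,   # optional {TensorVariable: TensorSharedVariable} replacements
    )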
10 changes: 6 additions & 4 deletions pymc/smc/kernels.py
@@ -579,11 +579,11 @@ def _logp_forw(point, out_vars, in_vars, shared):

     Parameters
     ----------
-    out_vars: List
+    out_vars : list
         containing :class:`pymc.Distribution` for the output variables
-    in_vars: List
+    in_vars : list
         containing :class:`pymc.Distribution` for the input variables
-    shared: List
+    shared : list
         containing :class:`aesara.tensor.Tensor` for depended shared data
     """

@@ -602,7 +602,9 @@ def _logp_forw(point, out_vars, in_vars, shared):
     out_vars = clone_replace(out_vars, replace_int_input, rebuild_strict=False)
     in_vars = new_in_vars

-    out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared)
+    out_list, inarray0 = join_nonshared_inputs(
+        point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
+    )
     f = compile_pymc([inarray0], out_list[0])
     f.trust_input = True
     return f
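For context on the `_logp_forw` change above: `compile_pymc([inarray0], out_list[0])` produces a function of a single raveled vector, so callers evaluate it on the concatenation of all non-shared input values. A minimal usage sketch, assuming `point`, `in_vars`, and `f` as in the function above (because `trust_input` is enabled, the array's dtype must match the compiled input):

    import numpy as np

    # Ravel each input value in the same order join_nonshared_inputs used,
    # then concatenate into the single joined vector the function expects.
    q = np.concatenate([np.asarray(point[v.name]).ravel() for v in in_vars])
    logp_value = f(q)  # forward log-probability at the joined point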
13 changes: 11 additions & 2 deletions pymc/step_methods/metropolis.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 from typing import Any, Callable, Dict, List, Optional, Tuple

+import aesara
 import numpy as np
 import numpy.random as nr
 import scipy.linalg
 import scipy.special

+from aesara import tensor as at
 from aesara.graph.fg import MissingInputError
 from aesara.tensor.random.basic import BernoulliRV, CategoricalRV

@@ -1052,8 +1054,15 @@ def sample_except(limit, excluded):
     return candidate


-def delta_logp(point, logp, vars, shared):
-    [logp0], inarray0 = join_nonshared_inputs(point, [logp], vars, shared)
+def delta_logp(
+    point: Dict[str, np.ndarray],
+    logp: at.TensorVariable,
+    vars: List[at.TensorVariable],
+    shared: Dict[at.TensorVariable, at.sharedvar.TensorSharedVariable],
+) -> aesara.compile.Function:
+    [logp0], inarray0 = join_nonshared_inputs(
+        point=point, outputs=[logp], inputs=vars, shared_inputs=shared
+    )

     tensor_type = inarray0.type
     inarray1 = tensor_type("inarray1")
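The rest of `delta_logp` sits below the diff fold, but the snippet shows its shape: `logp0` depends on one joined vector (`inarray0`) and a second vector (`inarray1`) is created from the same type. A hedged usage sketch, assuming the compiled function returns the log-probability difference between the two joined points, which is how the Metropolis samplers in this module consume it (`model` and `initial_point` as in the docstring examples earlier in this PR):

    import numpy as np

    rng = np.random.default_rng(1)

    # Hypothetical setup: a compiled delta-logp function for a model's value variables.
    f = delta_logp(point=initial_point, logp=model.logp(), vars=model.value_vars, shared={})

    q0 = np.concatenate([np.asarray(initial_point[v.name]).ravel() for v in model.value_vars])
    q = q0 + rng.normal(size=q0.shape)  # hypothetical random-walk proposal
    acceptance_logp = f(q, q0)          # logp(q) - logp(q0), fed into accept/reject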