Skip to content

Commit

Permalink
Prefix sampling_jax.replace_shared_variables with underscore
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Jan 30, 2022
1 parent 4b19716 commit 73add2d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def assert_fn(value, *inps):
return assert_fn


def replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
"""Replace shared variables in graph by their constant values
Raises
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_jaxified_graph(
) -> List[TensorVariable]:
"""Compile an Aesara graph into an optimized JAX function"""

graph = replace_shared_variables(outputs)
graph = _replace_shared_variables(outputs)

fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
# We need to add a Supervisor to the fgraph to be able to run the
Expand Down
6 changes: 3 additions & 3 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from pymc.sampling_jax import (
_get_log_likelihood,
_replace_shared_variables,
get_jaxified_graph,
get_jaxified_logp,
replace_shared_variables,
sample_numpyro_nuts,
)

Expand Down Expand Up @@ -95,13 +95,13 @@ def test_get_log_likelihood():
def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")

new_x = replace_shared_variables([x])
new_x = _replace_shared_variables([x])
shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
assert not shared_variables

x.default_update = x + 1
with pytest.raises(ValueError, match="shared variables with default_update"):
replace_shared_variables([x])
_replace_shared_variables([x])


def test_get_jaxified_logp():
Expand Down

0 comments on commit 73add2d

Please sign in to comment.