Skip to content

Commit 1559490

Browse files
committed
Shifting print_value to pytensorf.py
1 parent b41805c commit 1559490

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

pymc/model/core.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from pytensor.compile import DeepCopyOp, get_mode
4141
from pytensor.compile.sharedvalue import SharedVariable
4242
from pytensor.graph.basic import Constant, Variable, graph_inputs
43-
from pytensor.printing import Print
4443
from pytensor.scalar import Cast
4544
from pytensor.tensor.elemwise import Elemwise
4645
from pytensor.tensor.random.op import RandomVariable
@@ -2246,12 +2245,3 @@ def normal_logp(value, mu, sigma):
22462245
)
22472246

22482247
return var
2249-
2250-
2251-
def print_value(var, name=None):
2252-
"""Print value of variable when it is computed during sampling.
2253-
This is likely to affect sampling performance.
2254-
"""
2255-
if name is None:
2256-
name = var.name
2257-
return Print(name)(var)

pymc/pytensorf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
)
4141
from pytensor.graph.fg import FunctionGraph
4242
from pytensor.graph.op import Op
43+
from pytensor.printing import Print
4344
from pytensor.scalar.basic import Cast
4445
from pytensor.scan.op import Scan
4546
from pytensor.tensor.basic import _as_tensor_variable
@@ -275,6 +276,15 @@ def floatX(X):
275276
return np.asarray(X, dtype=pytensor.config.floatX)
276277

277278

279+
def print_value(var, name=None):
280+
"""Print value of variable when it is computed during sampling.
281+
This is likely to affect sampling performance.
282+
"""
283+
if name is None:
284+
name = var.name
285+
return Print(name)(var)
286+
287+
278288
_conversion_map = {"float64": "int32", "float32": "int16", "float16": "int8", "float8": "int8"}
279289

280290

0 commit comments

Comments
 (0)