Skip to content

Commit

Permalink
Fix latex representation for SharedVariable inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 3, 2022
1 parent 411fc0e commit 4262bdc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
27 changes: 17 additions & 10 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@

from typing import Union

from aesara.graph.basic import walk
from aesara.compile import SharedVariable
from aesara.graph.basic import Constant, walk
from aesara.tensor.basic import TensorVariable, Variable
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import RandomVariable
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
from aesara.tensor.var import TensorConstant

from pymc.model import Model

Expand Down Expand Up @@ -163,7 +163,7 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
# in case other code overrides str_repr, fallback
return False

if isinstance(var, TensorConstant):
if isinstance(var, (Constant, SharedVariable)):
return _str_for_constant(var, formatting)
elif isinstance(
var.owner.op, (RandomVariable, SymbolicRandomVariable)
Expand All @@ -189,15 +189,22 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
return _str


def _str_for_constant(var: TensorConstant, formatting: str) -> str:
if len(var.data.shape) == 0:
return f"{var.data:.3g}"
elif len(var.data.shape) == 1 and var.data.shape[0] == 1:
return f"{var.data[0]:.3g}"
def _str_for_constant(var: Union[Constant, SharedVariable], formatting: str) -> str:
if isinstance(var, Constant):
var_data = var.data
var_type = "constant"
else:
var_data = var.get_value()
var_type = "shared"

if len(var_data.shape) == 0:
return f"{var_data:.3g}"
elif len(var_data.shape) == 1 and var_data.shape[0] == 1:
return f"{var_data[0]:.3g}"
elif "latex" in formatting:
return r"\text{<constant>}"
return rf"\text{{<{var_type}>}}"
else:
return r"<constant>"
return rf"<{var_type}>"


def _str_for_expression(var: Variable, formatting: str) -> str:
Expand Down
42 changes: 42 additions & 0 deletions pymc/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,48 @@ def setup_class(self):
}


class TestData(BaseTestStrAndLatexRepr):
def setup_class(self):
with Model() as self.model:
import pymc as pm

with pm.Model() as model:
a = pm.Normal("a", pm.MutableData("a_data", (2,)))
b = pm.Normal("b", pm.MutableData("b_data", (2, 3)))
c = pm.Normal("c", pm.ConstantData("c_data", (2,)))
d = pm.Normal("d", pm.ConstantData("d_data", (2, 3)))

self.distributions = [a, b, c, d]
# tuples of (formatting, include_params)
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
self.expected = {
("plain", True): [
r"a ~ N(2, 1)",
r"b ~ N(<shared>, 1)",
r"c ~ N(2, 1)",
r"d ~ N(<constant>, 1)",
],
("plain", False): [
r"a ~ N",
r"b ~ N",
r"c ~ N",
r"d ~ N",
],
("latex", True): [
r"$\text{a} \sim \operatorname{N}(2,~1)$",
r"$\text{b} \sim \operatorname{N}(\text{<shared>},~1)$",
r"$\text{c} \sim \operatorname{N}(2,~1)$",
r"$\text{d} \sim \operatorname{N}(\text{<constant>},~1)$",
],
("latex", False): [
r"$\text{a} \sim \operatorname{N}$",
r"$\text{b} \sim \operatorname{N}$",
r"$\text{c} \sim \operatorname{N}$",
r"$\text{d} \sim \operatorname{N}$",
],
}


def test_model_latex_repr_three_levels_model():
with Model() as censored_model:
mu = Normal("mu", 0.0, 5.0)
Expand Down

0 comments on commit 4262bdc

Please sign in to comment.