Skip to content

Commit 929a3b5

Browse files
Mixture dists look alright
1 parent 0fae5ea commit 929a3b5

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,9 @@ def __new__(
484484
initval=initval,
485485
)
486486

487+
# TODO: if rv_out is a cloned variable, the line below wouldn't work
487488
set_print_name(cls, rv_out)
489+
488490
rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out)
489491
rv_out._repr_latex_ = types.MethodType(
490492
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out

pymc/printing.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,49 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params:
5959
def str_for_symbolic_dist(
6060
rv: TensorVariable, formatting: str = "plain", include_params: bool = True
6161
) -> str:
62+
def dispatch_comp_str(var, formatting=formatting, include_params=include_params):
63+
if var.name:
64+
return str_for_dist(var, formatting=formatting, include_params=include_params)
65+
if isinstance(var, TensorConstant):
66+
return _str_for_constant(var, formatting)
67+
if var.owner.op.name == "constant":
68+
return "hello"
69+
70+
# else it's a Mixture component initialized by the .dist() API
71+
72+
dist_args = ", ".join(
73+
[_str_for_input_var(x, formatting=formatting) for x in var.owner.inputs[3:]]
74+
)
75+
comp_name = var.owner.op.name.capitalize()
76+
77+
if "latex" in formatting:
78+
comp_name = r"\text{" + _latex_escape(comp_name) + "}"
6279

63-
# code would look something like this:
64-
# if "ZeroInflated" in rv.owner.op._print_name[0]:
65-
# return
66-
# if "Mixture" in rv.owner.op._print_name[0]:
67-
# return
80+
return f"{comp_name}({dist_args})"
6881

69-
# below is copy-pasted from str_for_dist
7082
if include_params:
71-
# first 3 args are always (rng, size, dtype), rest is relevant for distribution
72-
dist_args = [_str_for_input_var(x, formatting=formatting) for x in rv.owner.inputs[3:]]
83+
if "ZeroInflated" in rv.owner.op._print_name[0]:
84+
start_idx_para = 2
85+
elif "Mixture" in rv.owner.op._print_name[0]:
86+
start_idx_para = 1
87+
88+
if len(rv.owner.inputs) == 3:
89+
# is a single component!
90+
# (rng, weights, single_component)
91+
pass
92+
93+
elif "Censored" in rv.owner.op._print_name[0]:
94+
start_idx_para = 2
95+
else:
96+
raise ValueError(
97+
"Latex printing not yet implemented for this SymbolicDistribution\n"
98+
"Please update the str_for_symbolic_dist in pymc/printing.py file."
99+
)
100+
101+
dist_args = [
102+
dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params)
103+
for dist_para in rv.owner.inputs[start_idx_para:]
104+
]
73105

74106
print_name = rv.name if rv.name is not None else "<unnamed>"
75107
if "latex" in formatting:
@@ -165,7 +197,14 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
165197

166198

167199
def _str_for_input_rv(var: Variable, formatting: str) -> str:
168-
_str = var.name if var.name is not None else "<unnamed>"
200+
201+
if var.name:
202+
_str = var.name
203+
elif var.owner.op.name:
204+
_str = var.owner.op.name.capitalize()
205+
else:
206+
_str = "<unnamed>"
207+
169208
if "latex" in formatting:
170209
return r"\text{" + _latex_escape(_str) + "}"
171210
else:

0 commit comments

Comments
 (0)