1717from typing import Union
1818
1919from aesara .graph .basic import walk
20- from aesara .tensor .basic import TensorVariable , Variable
20+ from aesara .tensor .basic import MakeVector , TensorVariable , Variable
2121from aesara .tensor .elemwise import DimShuffle
2222from aesara .tensor .random .basic import RandomVariable
2323from aesara .tensor .var import TensorConstant
2424
2525from pymc .model import Model
2626
27+ # from pymc.distributions.discrete import UnmeasurableConstantRV
28+
2729__all__ = [
2830 "str_for_dist" ,
2931 "str_for_symbolic_dist" ,
@@ -61,11 +63,12 @@ def str_for_symbolic_dist(
6163) -> str :
6264 def dispatch_comp_str (var , formatting = formatting , include_params = include_params ):
6365 if var .name :
64- return str_for_dist ( var , formatting = formatting , include_params = include_params )
66+ return var . name
6567 if isinstance (var , TensorConstant ):
66- return _str_for_constant (var , formatting )
67- if var .owner .op .name == "constant" :
68- return "hello"
68+ return _str_for_constant (var , formatting , print_vector = True )
69+ if isinstance (var .owner .op , MakeVector ):
70+ # psi in some zero inflated distribution
71+ return dispatch_comp_str (var .owner .inputs [1 ])
6972
7073 # else it's a Mixture component initialized by the .dist() API
7174
@@ -81,27 +84,36 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
8184
8285 if include_params :
8386 if "ZeroInflated" in rv .owner .op ._print_name [0 ]:
84- start_idx_para = 2
87+ # position 2 is just a constant_rv{0, (0,), shape, False}.1
88+ assert rv .owner .inputs [2 ].owner .op .__class__ .__name__ == "UnmeasurableConstantRV"
89+ dist_parameters = [rv .owner .inputs [1 ]] + rv .owner .inputs [3 :]
90+
8591 elif "Mixture" in rv .owner .op ._print_name [0 ]:
86- start_idx_para = 1
8792
8893 if len (rv .owner .inputs ) == 3 :
8994 # is a single component!
9095 # (rng, weights, single_component)
91- pass
96+ rv .owner .op ._print_name = (
97+ f"{ rv .owner .inputs [2 ].owner .op .name .capitalize ()} Mixture" ,
98+ "\\ operatorname{" + f"{ rv .owner .inputs [2 ].owner .op .name .capitalize ()} Mixture}}" ,
99+ )
100+ dist_parameters = [rv .owner .inputs [1 ]] + rv .owner .inputs [2 ].owner .inputs [3 :]
101+ else :
102+ dist_parameters = rv .owner .inputs [1 :]
92103
93104 elif "Censored" in rv .owner .op ._print_name [0 ]:
94- start_idx_para = 2
105+ dist_parameters = rv . owner . inputs [ 2 :]
95106 else :
96107 # Latex representation for the SymbolicDistribution has not been implemented.
97108 # Hoping for the best here!
98- start_idx_para = 2
109+ dist_parameters = rv . owner . inputs [ 2 :]
99110
100111 dist_args = [
101112 dispatch_comp_str (dist_para , formatting = formatting , include_params = include_params )
102- for dist_para in rv . owner . inputs [ start_idx_para :]
113+ for dist_para in dist_parameters
103114 ]
104115
116+ # code below copied from str_for_dist
105117 print_name = rv .name if rv .name is not None else "<unnamed>"
106118 if "latex" in formatting :
107119 print_name = r"\text{" + _latex_escape (print_name ) + "}"
@@ -210,11 +222,13 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
210222 return _str
211223
212224
213- def _str_for_constant (var : TensorConstant , formatting : str ) -> str :
225+ def _str_for_constant (var : TensorConstant , formatting : str , print_vector : bool = False ) -> str :
214226 if len (var .data .shape ) == 0 :
215227 return f"{ var .data :.3g} "
216228 elif len (var .data .shape ) == 1 and var .data .shape [0 ] == 1 :
217229 return f"{ var .data [0 ]:.3g} "
230+ elif len (var .data .shape ) == 1 and print_vector :
231+ return "[" + ", " .join ([f"{ const :.3g} " for const in var .data ]) + "]"
218232 elif "latex" in formatting :
219233 return r"\text{<constant>}"
220234 else :
0 commit comments