@@ -59,17 +59,49 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params:
5959def str_for_symbolic_dist (
6060 rv : TensorVariable , formatting : str = "plain" , include_params : bool = True
6161) -> str :
62+ def dispatch_comp_str (var , formatting = formatting , include_params = include_params ):
63+ if var .name :
64+ return str_for_dist (var , formatting = formatting , include_params = include_params )
65+ if isinstance (var , TensorConstant ):
66+ return _str_for_constant (var , formatting )
67+ if var .owner .op .name == "constant" :
68+ return "hello"
69+
70+ # else it's a Mixture component initialized by the .dist() API
71+
72+ dist_args = ", " .join (
73+ [_str_for_input_var (x , formatting = formatting ) for x in var .owner .inputs [3 :]]
74+ )
75+ comp_name = var .owner .op .name .capitalize ()
76+
77+ if "latex" in formatting :
78+ comp_name = r"\text{" + _latex_escape (comp_name ) + "}"
6279
63- # code would look something like this:
64- # if "ZeroInflated" in rv.owner.op._print_name[0]:
65- # return
66- # if "Mixture" in rv.owner.op._print_name[0]:
67- # return
80+ return f"{ comp_name } ({ dist_args } )"
6881
69- # below is copy-pasted from str_for_dist
7082 if include_params :
71- # first 3 args are always (rng, size, dtype), rest is relevant for distribution
72- dist_args = [_str_for_input_var (x , formatting = formatting ) for x in rv .owner .inputs [3 :]]
83+ if "ZeroInflated" in rv .owner .op ._print_name [0 ]:
84+ start_idx_para = 2
85+ elif "Mixture" in rv .owner .op ._print_name [0 ]:
86+ start_idx_para = 1
87+
88+ if len (rv .owner .inputs ) == 3 :
89+ # is a single component!
90+ # (rng, weights, single_component)
91+ pass
92+
93+ elif "Censored" in rv .owner .op ._print_name [0 ]:
94+ start_idx_para = 2
95+ else :
96+ raise ValueError (
97+ "Latex printing not yet implemented for this SymbolicDistribution\n "
98+ "Please update the str_for_symbolic_dist in pymc/printing.py file."
99+ )
100+
101+ dist_args = [
102+ dispatch_comp_str (dist_para , formatting = formatting , include_params = include_params )
103+ for dist_para in rv .owner .inputs [start_idx_para :]
104+ ]
73105
74106 print_name = rv .name if rv .name is not None else "<unnamed>"
75107 if "latex" in formatting :
@@ -165,7 +197,14 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
165197
166198
167199def _str_for_input_rv (var : Variable , formatting : str ) -> str :
168- _str = var .name if var .name is not None else "<unnamed>"
200+
201+ if var .name :
202+ _str = var .name
203+ elif var .owner .op .name :
204+ _str = var .owner .op .name .capitalize ()
205+ else :
206+ _str = "<unnamed>"
207+
169208 if "latex" in formatting :
170209 return r"\text{" + _latex_escape (_str ) + "}"
171210 else :
0 commit comments