|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import itertools |
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | from typing import Union |
18 | 19 |
|
|
24 | 25 |
|
25 | 26 | from pymc.model import Model |
26 | 27 |
|
27 | | -# from pymc.distributions.discrete import UnmeasurableConstantRV |
28 | | - |
29 | 28 | __all__ = [ |
30 | 29 | "str_for_dist", |
31 | 30 | "str_for_symbolic_dist", |
@@ -65,7 +64,15 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params) |
65 | 64 | if var.name: |
66 | 65 | return var.name |
67 | 66 | if isinstance(var, TensorConstant): |
68 | | - return _str_for_constant(var, formatting, print_vector=True) |
| 67 | + if len(var.data.shape) > 1: |
| 68 | + raise NotImplementedError |
| 69 | + try: |
| 70 | + if var.data.shape[0] > 1: |
| 71 | + # weights in mixture model |
| 72 | + return "[" + ",".join([str(weight) for weight in var.data]) + "]" |
| 73 | + except IndexError: |
| 74 | + # just a scalar |
| 75 | + return _str_for_constant(var, formatting) |
69 | 76 | if isinstance(var.owner.op, MakeVector): |
70 | 77 | # psi in some zero inflated distribution |
71 | 78 | return dispatch_comp_str(var.owner.inputs[1]) |
@@ -102,11 +109,17 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params) |
102 | 109 | dist_parameters = rv.owner.inputs[1:] |
103 | 110 |
|
104 | 111 | elif "Censored" in rv.owner.op._print_name[0]: |
105 | | - dist_parameters = rv.owner.inputs[2:] |
| 112 | + dist_parameters = rv.owner.inputs |
106 | 113 | else: |
107 | 114 | # Latex representation for the SymbolicDistribution has not been implemented. |
108 | 115 | # Hoping for the best here! |
109 | 116 | dist_parameters = rv.owner.inputs[2:] |
| 117 | + warnings.warn( |
| 118 | + "Latex representation for this SymbolicDistribution has not been implemented. " |
| 119 | + "Please have a look at str_for_symbolic_dist in pymc/printing.py", |
| 120 | + FutureWarning, |
| 121 | + stacklevel=2, |
| 122 | + ) |
110 | 123 |
|
111 | 124 | dist_args = [ |
112 | 125 | dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params) |
@@ -222,13 +235,11 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str: |
222 | 235 | return _str |
223 | 236 |
|
224 | 237 |
|
225 | | -def _str_for_constant(var: TensorConstant, formatting: str, print_vector: bool = False) -> str: |
| 238 | +def _str_for_constant(var: TensorConstant, formatting: str) -> str: |
226 | 239 | if len(var.data.shape) == 0: |
227 | 240 | return f"{var.data:.3g}" |
228 | 241 | elif len(var.data.shape) == 1 and var.data.shape[0] == 1: |
229 | 242 | return f"{var.data[0]:.3g}" |
230 | | - elif len(var.data.shape) == 1 and print_vector: |
231 | | - return "[" + ", ".join([f"{const:.3g}" for const in var.data]) + "]" |
232 | 243 | elif "latex" in formatting: |
233 | 244 | return r"\text{<constant>}" |
234 | 245 | else: |
|
0 commit comments