55import yaml
66from matplotlib import pyplot as plt
77from mu_ppl import viz
8- from mu_ppl .distributions import Categorical
98
109from ._version import version
1110from .pdl import InterpreterConfig
1211from .pdl_ast import PdlLocationType , Program , ScopeType , get_default_model_parameters
12+ from .pdl_distributions import Categorical
1313from .pdl_inference import (
1414 infer_importance_sampling ,
1515 infer_importance_sampling_parallel ,
16- infer_rejection ,
17- infer_rejection_parallel ,
16+ infer_rejection_sampling ,
17+ infer_rejection_sampling_parallel ,
1818 infer_smc ,
1919 infer_smc_parallel ,
2020)
@@ -42,7 +42,7 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
4242 ppdl_config : Optional [PpdlConfig ] = None ,
4343 scope : Optional [ScopeType | dict [str , Any ]] = None ,
4444 loc : Optional [PdlLocationType ] = None ,
45- output : Literal ["result" , "all" ] = "result" ,
45+ # output: Literal["result", "all"] = "result",
4646) -> Categorical [Any ]:
4747 ppdl_config = ppdl_config or PpdlConfig ()
4848
@@ -56,46 +56,42 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
5656 config ["batch" ] = 1
5757 config ["event_loop" ] = _LOOP
5858
59+ dist : Categorical [Any ]
5960 match algo :
6061 case "is" :
6162 dist = infer_importance_sampling (
62- prog , config , scope , loc , output , num_particles = num_particles
63+ prog , config , scope , loc , num_particles = num_particles
6364 )
6465 case "parallel-is" :
6566 dist = infer_importance_sampling_parallel (
6667 prog ,
6768 config ,
6869 scope ,
6970 loc ,
70- output ,
7171 num_particles = num_particles ,
7272 max_workers = max_workers ,
7373 )
7474 case "smc" :
75- dist = infer_smc (
76- prog , config , scope , loc , output , num_particles = num_particles
77- )
75+ dist = infer_smc (prog , config , scope , loc , num_particles = num_particles )
7876 case "parallel-smc" :
7977 dist = infer_smc_parallel (
8078 prog ,
8179 config ,
8280 scope ,
8381 loc ,
84- output ,
8582 num_particles = num_particles ,
8683 max_workers = max_workers ,
8784 )
8885 case "rejection" :
89- dist = infer_rejection (
90- prog , config , scope , loc , output , num_samples = num_particles
86+ dist = infer_rejection_sampling (
87+ prog , config , scope , loc , num_samples = num_particles
9188 )
9289 case "parallel-rejection" :
93- dist = infer_rejection_parallel (
90+ dist = infer_rejection_sampling_parallel (
9491 prog ,
9592 config ,
9693 scope ,
9794 loc ,
98- output ,
9995 num_samples = num_particles ,
10096 max_workers = max_workers ,
10197 )
@@ -110,10 +106,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume
110106 ppdl_config : Optional [PpdlConfig ] = None ,
111107 scope : Optional [ScopeType | dict [str , Any ]] = None ,
112108 loc : Optional [PdlLocationType ] = None ,
113- output : Literal ["result" , "all" ] = "result" ,
109+ # output: Literal["result", "all"] = "result",
114110) -> Any :
115111 program = parse_dict (prog )
116- result = exec_program (program , config , ppdl_config , scope , loc , output )
112+ result = exec_program (program , config , ppdl_config , scope , loc )
117113 return result
118114
119115
@@ -122,10 +118,10 @@ def exec_str(
122118 config : Optional [InterpreterConfig ] = None ,
123119 ppdl_config : Optional [PpdlConfig ] = None ,
124120 scope : Optional [ScopeType | dict [str , Any ]] = None ,
125- output : Literal ["result" , "all" ] = "result" ,
121+ # output: Literal["result", "all"] = "result",
126122) -> Any :
127123 program , loc = parse_str (prog )
128- result = exec_program (program , config , ppdl_config , scope , loc , output )
124+ result = exec_program (program , config , ppdl_config , scope , loc )
129125 return result
130126
131127
@@ -134,14 +130,14 @@ def exec_file(
134130 config : Optional [InterpreterConfig ] = None ,
135131 ppdl_config : Optional [PpdlConfig ] = None ,
136132 scope : Optional [ScopeType | dict [str , Any ]] = None ,
137- output : Literal ["result" , "all" ] = "result" ,
133+ # output: Literal["result", "all"] = "result",
138134) -> Any :
139135 program , loc = parse_file (prog )
140136 if config is None :
141137 config = InterpreterConfig ()
142138 if config .get ("cwd" ) is None :
143139 config ["cwd" ] = Path (prog ).parent
144- result = exec_program (program , config , ppdl_config , scope , loc , output )
140+ result = exec_program (program , config , ppdl_config , scope , loc )
145141 return result
146142
147143
0 commit comments