Skip to content

Commit 94b0217

Browse files
committed
feat: probabilistic inference returns traces
Signed-off-by: Louis Mandel <lmandel@us.ibm.com>
1 parent 5927b75 commit 94b0217

File tree

3 files changed

+176
-108
lines changed

3 files changed

+176
-108
lines changed

src/pdl/pdl_distributions.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Adapted from mu-ppl: https://github.com/gbdrt/mu-ppl/blob/main/mu_ppl/distributions.py
2+
3+
from typing import Any, Generic, TypeVar
4+
import numpy as np
5+
import numpy.random as rand
6+
from scipy.special import logsumexp
7+
import seaborn as sns
8+
9+
T = TypeVar("T")


class Categorical(Generic[T]):
    """
    Categorical distribution, i.e., finite support distribution where values can be of arbitrary type.

    The constructor stores four parallel sequences:
      - values: the support of the distribution (may contain duplicates),
      - logits: the unnormalized scores, in log scale,
      - probs: the normalized probabilities derived from the logits,
      - metadata: one list of extra information per value.
    """

    def __init__(self, tuples: list[tuple[T, float, list[Any]]]):
        """
        Args:
            tuples: List of tuples (value, score, metadata), where the score is in log scale.

        Raises:
            ValueError: If `tuples` is empty.
        """
        if not tuples:
            # Unpacking an empty zip would raise an opaque ValueError anyway;
            # keep the exception type but make the message explicit.
            raise ValueError("Categorical requires at least one (value, score, metadata) tuple")
        self.values, self.logits, self.metadata = zip(*tuples)
        logits = np.asarray(self.logits, dtype=float)
        # Normalize in log space to stay numerically stable.
        self.probs = np.exp(logits - logsumexp(logits))

    def shrink(self) -> "Categorical[T]":
        """
        Create an equivalent distribution without duplicated values.

        Probabilities of equal values are summed and their metadata lists are
        concatenated. Values are used as dictionary keys, so they must be hashable.

        Returns:
            A new distribution with one entry per distinct value, in order of
            first occurrence.
        """
        merged: dict[T, tuple[float, list]] = {}
        for value, prob, meta in zip(self.values, self.probs, self.metadata):
            if value in merged:
                acc_prob, acc_meta = merged[value]
                merged[value] = (acc_prob + prob, acc_meta + meta)
            else:
                merged[value] = (prob, meta)
        # The constructor expects scores in log scale, so the accumulated
        # probabilities must be converted back with log. A zero probability
        # legitimately maps to -inf, hence the silenced divide warning.
        with np.errstate(divide="ignore"):
            return Categorical(
                [
                    (value, float(np.log(prob)), meta)
                    for value, (prob, meta) in merged.items()
                ]
            )

    def sample(self) -> T:
        """
        Draw one value according to the probabilities of the distribution.
        """
        u = rand.rand()
        # u is in [0, 1): side="right" selects the first index whose cumulative
        # probability is strictly greater than u, so a zero-probability entry
        # is not picked when u == 0.0.
        i = int(np.searchsorted(np.cumsum(self.probs), u, side="right"))
        # Rounding can leave the last cumulative sum slightly below 1.0; clamp
        # the index so that such a draw does not raise IndexError.
        return self.values[min(i, len(self.values) - 1)]

    def sort(self) -> "Categorical[T]":
        """
        Create an equivalent distribution without duplicated values, ordered
        by decreasing score.

        Returns:
            A new distribution whose values and metadata are lists and whose
            logits and probs are numpy arrays, all in the same sorted order.
        """
        d = self.shrink()
        order = np.argsort(d.logits)[::-1]
        d.values = [d.values[i] for i in order]
        d.logits = np.array(d.logits)[order]
        d.probs = np.array(d.probs)[order]
        d.metadata = [d.metadata[i] for i in order]
        return d
53+
54+
55+
def viz(dist: Categorical[float], **kwargs):
    """
    Visualize a distribution.

    The distribution is first deduplicated with `shrink()`. A support of fewer
    than 100 values is drawn as a bar plot; a larger one is drawn as a weighted
    50-bin histogram with a KDE overlay.

    Args:
        dist: Distribution to plot.
        **kwargs: Extra keyword arguments forwarded to the seaborn plotting call.
    """
    compact = dist.shrink()
    support, weights = compact.values, compact.probs
    if len(support) >= 100:
        # Too many distinct values for individual bars: bin them instead.
        sns.histplot(
            x=support,
            weights=weights,
            bins=50,
            kde=True,
            stat="probability",
            **kwargs,
        )
        return
    sns.barplot(x=support, y=weights, errorbar=None, **kwargs)

src/pdl/pdl_infer.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
import yaml
66
from matplotlib import pyplot as plt
77
from mu_ppl import viz
8-
from mu_ppl.distributions import Categorical
98

109
from ._version import version
1110
from .pdl import InterpreterConfig
1211
from .pdl_ast import PdlLocationType, Program, ScopeType, get_default_model_parameters
12+
from .pdl_distributions import Categorical
1313
from .pdl_inference import (
1414
infer_importance_sampling,
1515
infer_importance_sampling_parallel,
16-
infer_rejection,
17-
infer_rejection_parallel,
16+
infer_rejection_sampling,
17+
infer_rejection_sampling_parallel,
1818
infer_smc,
1919
infer_smc_parallel,
2020
)
@@ -42,7 +42,7 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
4242
ppdl_config: Optional[PpdlConfig] = None,
4343
scope: Optional[ScopeType | dict[str, Any]] = None,
4444
loc: Optional[PdlLocationType] = None,
45-
output: Literal["result", "all"] = "result",
45+
# output: Literal["result", "all"] = "result",
4646
) -> Categorical[Any]:
4747
ppdl_config = ppdl_config or PpdlConfig()
4848

@@ -56,46 +56,42 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
5656
config["batch"] = 1
5757
config["event_loop"] = _LOOP
5858

59+
dist: Categorical[Any]
5960
match algo:
6061
case "is":
6162
dist = infer_importance_sampling(
62-
prog, config, scope, loc, output, num_particles=num_particles
63+
prog, config, scope, loc, num_particles=num_particles
6364
)
6465
case "parallel-is":
6566
dist = infer_importance_sampling_parallel(
6667
prog,
6768
config,
6869
scope,
6970
loc,
70-
output,
7171
num_particles=num_particles,
7272
max_workers=max_workers,
7373
)
7474
case "smc":
75-
dist = infer_smc(
76-
prog, config, scope, loc, output, num_particles=num_particles
77-
)
75+
dist = infer_smc(prog, config, scope, loc, num_particles=num_particles)
7876
case "parallel-smc":
7977
dist = infer_smc_parallel(
8078
prog,
8179
config,
8280
scope,
8381
loc,
84-
output,
8582
num_particles=num_particles,
8683
max_workers=max_workers,
8784
)
8885
case "rejection":
89-
dist = infer_rejection(
90-
prog, config, scope, loc, output, num_samples=num_particles
86+
dist = infer_rejection_sampling(
87+
prog, config, scope, loc, num_samples=num_particles
9188
)
9289
case "parallel-rejection":
93-
dist = infer_rejection_parallel(
90+
dist = infer_rejection_sampling_parallel(
9491
prog,
9592
config,
9693
scope,
9794
loc,
98-
output,
9995
num_samples=num_particles,
10096
max_workers=max_workers,
10197
)
@@ -110,10 +106,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume
110106
ppdl_config: Optional[PpdlConfig] = None,
111107
scope: Optional[ScopeType | dict[str, Any]] = None,
112108
loc: Optional[PdlLocationType] = None,
113-
output: Literal["result", "all"] = "result",
109+
# output: Literal["result", "all"] = "result",
114110
) -> Any:
115111
program = parse_dict(prog)
116-
result = exec_program(program, config, ppdl_config, scope, loc, output)
112+
result = exec_program(program, config, ppdl_config, scope, loc)
117113
return result
118114

119115

@@ -122,10 +118,10 @@ def exec_str(
122118
config: Optional[InterpreterConfig] = None,
123119
ppdl_config: Optional[PpdlConfig] = None,
124120
scope: Optional[ScopeType | dict[str, Any]] = None,
125-
output: Literal["result", "all"] = "result",
121+
# output: Literal["result", "all"] = "result",
126122
) -> Any:
127123
program, loc = parse_str(prog)
128-
result = exec_program(program, config, ppdl_config, scope, loc, output)
124+
result = exec_program(program, config, ppdl_config, scope, loc)
129125
return result
130126

131127

@@ -134,14 +130,14 @@ def exec_file(
134130
config: Optional[InterpreterConfig] = None,
135131
ppdl_config: Optional[PpdlConfig] = None,
136132
scope: Optional[ScopeType | dict[str, Any]] = None,
137-
output: Literal["result", "all"] = "result",
133+
# output: Literal["result", "all"] = "result",
138134
) -> Any:
139135
program, loc = parse_file(prog)
140136
if config is None:
141137
config = InterpreterConfig()
142138
if config.get("cwd") is None:
143139
config["cwd"] = Path(prog).parent
144-
result = exec_program(program, config, ppdl_config, scope, loc, output)
140+
result = exec_program(program, config, ppdl_config, scope, loc)
145141
return result
146142

147143

0 commit comments

Comments
 (0)