Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

Moving to a functional interface #88

Merged
merged 56 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
d332170
Compile wiring with individual factors
StannisZhou Oct 23, 2021
d7ac261
Use ordered dict for keys to factors
StannisZhou Oct 23, 2021
82a4c95
Rename keys to variables; Simplify
StannisZhou Oct 23, 2021
efde82d
Store variables to factors for factor graph
StannisZhou Oct 23, 2021
2d5a1a4
Simplify messages manipulation
StannisZhou Oct 23, 2021
fe2f14c
Shorten name
StannisZhou Oct 23, 2021
8fc6461
Don't run ci for regular push
StannisZhou Oct 23, 2021
d757d7c
Rough outline for log potentials
StannisZhou Oct 23, 2021
d334605
Log potentials manipulation
StannisZhou Oct 24, 2021
23583d9
Change Messages to BPState
StannisZhou Oct 24, 2021
0827313
Allow log potentials for individual factors
StannisZhou Oct 24, 2021
a029158
Make BPState independent of factor graph
StannisZhou Oct 24, 2021
f8265e3
New classes for functional interface
StannisZhou Oct 25, 2021
77e49ba
Get rid of default modes
StannisZhou Oct 25, 2021
2dc2635
Functional updates functions
StannisZhou Oct 25, 2021
d247bd8
Functional BP interface
StannisZhou Oct 25, 2021
f023082
Remove old implementation
StannisZhou Oct 25, 2021
b00345f
Use functions for setitem
StannisZhou Oct 25, 2021
ab00866
Implement decode map states
StannisZhou Oct 25, 2021
28f7696
Make Ising model example run again
StannisZhou Oct 25, 2021
d6acb08
Updated ising model notebook
StannisZhou Oct 25, 2021
b590a5e
Make RBM example run again
StannisZhou Oct 25, 2021
0189152
Implement flatten/unflatten for variable groups
StannisZhou Oct 25, 2021
05651c9
Use flatten in evidence updates
StannisZhou Oct 25, 2021
19c59db
flatten/unflatten for factor groups
StannisZhou Oct 25, 2021
c824474
Use flatten in log potentials updates
StannisZhou Oct 25, 2021
70c726c
Simplify decode map states
StannisZhou Oct 25, 2021
0ca1a3d
Fix flatten/unflatten
StannisZhou Oct 25, 2021
eb035cf
Get rid of copy in unflatten
StannisZhou Oct 25, 2021
98fae8f
Update notebooks
StannisZhou Oct 25, 2021
9af46d1
Fix all notebooks
StannisZhou Oct 25, 2021
b1191e6
Add examples for batching and gradients
StannisZhou Oct 25, 2021
33b9899
Separate out decode_map_states
StannisZhou Oct 26, 2021
7f13dc6
Make test_pgmax pass
StannisZhou Oct 26, 2021
4fe7a1b
New test groups
StannisZhou Oct 26, 2021
3d8698d
Fix test groups
StannisZhou Oct 26, 2021
86f467e
New test nodes
StannisZhou Oct 26, 2021
eda93fc
Pass test graph
StannisZhou Oct 26, 2021
5d63a12
Full coverage of graph
StannisZhou Oct 26, 2021
5b7b44e
Support default log_potential_matrix for pairwise factor groups
StannisZhou Oct 26, 2021
4e3856a
Full coverage
StannisZhou Oct 26, 2021
1a35622
Docstrings
StannisZhou Oct 26, 2021
dce2b2f
Docstrings
StannisZhou Oct 26, 2021
802d9b3
Separate add factor functions to clarify
StannisZhou Oct 27, 2021
5ab3a3b
Update examples
StannisZhou Oct 27, 2021
030ca03
Fix tests
StannisZhou Oct 27, 2021
ed95289
Change key to name
StannisZhou Oct 27, 2021
bd8899c
More renaming
StannisZhou Oct 27, 2021
8833e2f
Try to avoid memory leaks
StannisZhou Oct 27, 2021
e80d4c2
functools.partial
StannisZhou Oct 27, 2021
824a37c
Small fixes
StannisZhou Oct 31, 2021
ba55c88
Add raises
StannisZhou Oct 31, 2021
d0f6ad9
Fix mypy error
StannisZhou Oct 31, 2021
b454b95
Raises
StannisZhou Nov 1, 2021
7c4690a
Comments
StannisZhou Nov 1, 2021
32d9c48
Fix tests
StannisZhou Nov 1, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: continuous-integration

on:
push:
branches:
- '*'
pull_request:
branches:
- master
Expand Down
54 changes: 21 additions & 33 deletions examples/heretic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
# ---

# %%
# %matplotlib inline
# Standard Package Imports
from dataclasses import replace
from timeit import default_timer as timer
from typing import Any, List, Tuple

import jax
import jax.numpy as jnp

# %%
# %matplotlib inline
# Standard Package Imports
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -116,18 +115,17 @@ def binary_connected_variables(
W_pot = W_orig.swapaxes(0, 1)
for k_row in range(3):
for k_col in range(3):
fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=binary_connected_variables(28, 28, k_row, k_col),
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
connected_variable_names=binary_connected_variables(28, 28, k_row, k_col),
log_potential_matrix=W_pot[:, :, k_row, k_col],
)


# %% [markdown]
# # Construct Initial Messages

# %%


def custom_flatten_ordering(Mdown, Mup):
flat_idx = 0
flat_Mdown = Mdown.flatten()
Expand Down Expand Up @@ -177,30 +175,27 @@ def custom_flatten_ordering(Mdown, Mup):

# %% tags=[]
# Run BP
init_msgs = fg.get_init_msgs()
init_msgs.ftov = graph.FToVMessages(
factor_graph=fg,
init_value=jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
bp_state = replace(
fg.bp_state,
ftov_msgs=graph.FToVMessages(
fg_state=fg.fg_state,
value=jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
),
StannisZhou marked this conversation as resolved.
Show resolved Hide resolved
),
)
init_msgs.evidence[0] = np.array(bXn_evidence)
init_msgs.evidence[1] = np.array(bHn_evidence)
bp_state.evidence[0] = np.array(bXn_evidence)
bp_state.evidence[1] = np.array(bHn_evidence)
run_bp, _, get_beliefs = graph.BP(bp_state, 500)
bp_start_time = timer()
# Assign evidence to pixel vars
final_msgs = fg.run_bp(
500,
0.5,
init_msgs=init_msgs,
)
bp_arrays = run_bp()
bp_end_time = timer()
print(f"time taken for bp {bp_end_time - bp_start_time}")

# Run inference and convert result to human-readable data structure
data_writeback_start_time = timer()
map_message_dict = fg.decode_map_states(
final_msgs,
)
map_states = graph.decode_map_states(get_beliefs(bp_arrays))
data_writeback_end_time = timer()
print(
f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}"
Expand Down Expand Up @@ -236,13 +231,6 @@ def plot_images(images):


# %%
img_arr = np.zeros((1, im_size[0], im_size[1]))

for row in range(im_size[0]):
for col in range(im_size[1]):
img_val = float(map_message_dict[0, row, col])
if img_val == 2.0:
img_val = 0.4
img_arr[0, row, col] = img_val * 1.0

img_arr = map_states[0][None].copy().astype(float)
img_arr[img_arr == 2.0] = 0.4
plot_images(img_arr)
75 changes: 53 additions & 22 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# %%
# %matplotlib inline
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -25,18 +27,18 @@

# %%
variables = groups.NDVariableArray(variable_size=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables, evidence_default_mode="random")
connected_var_keys = []
fg = graph.FactorGraph(variables=variables)
connected_variable_names = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
connected_var_keys.append([(ii, jj), (kk, jj)])
connected_var_keys.append([(ii, jj), (ii, ll)])
connected_variable_names.append([(ii, jj), (kk, jj)])
connected_variable_names.append([(ii, jj), (ii, ll)])

fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=connected_var_keys,
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
connected_variable_names=connected_variable_names,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)
Expand All @@ -45,43 +47,72 @@
# ### Run inference and visualize results

# %%
msgs = fg.run_bp(3000, 0.5)
map_states = fg.decode_map_states(msgs)
img = np.zeros((50, 50))
for key in map_states:
img[key] = map_states[key]
bp_state = fg.bp_state
run_bp, _, get_beliefs = graph.BP(bp_state, 3000)

# %%
bp_arrays = run_bp(
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)

# %%
img = graph.decode_map_states(get_beliefs(bp_arrays))
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)


# %% [markdown]
# ### Gradients and batching

# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = run_bp(
log_potentials_updates=log_potentials_updates, evidence_updates=evidence_updates
)
beliefs = get_beliefs(bp_arrays)
loss = -jnp.sum(beliefs)
return loss


batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))

# %%
batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})

# %%
grads = log_potentials_grads(
{"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)

# %% [markdown]
# ### Message and evidence manipulation

# %%
# Query evidence for variable (0, 0)
msgs.evidence[0, 0]
bp_state.evidence[0, 0]

# %%
# Set evidence for variable (0, 0)
msgs.evidence[0, 0] = np.array([1.0, 1.0])
msgs.evidence[0, 0]
bp_state.evidence[0, 0] = np.array([1.0, 1.0])
bp_state.evidence[0, 0]

# %%
# Set evidence for all variables using an array
evidence = np.random.randn(50, 50, 2)
msgs.evidence[:] = evidence
msgs.evidence[10, 10] == evidence[10, 10]
bp_state.evidence[None] = evidence
bp_state.evidence[10, 10] == evidence[10, 10]

# %%
# Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Uniformly spread expected belief at a variable to all connected factors
msgs.ftov[0, 0] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[0, 0] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]
48 changes: 32 additions & 16 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# %matplotlib inline
import itertools

import jax
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -47,25 +48,40 @@
)

# %%
# Set evidence
init_msgs = fg.get_init_msgs()
init_msgs.evidence["hidden"] = np.stack(
[np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1
)
init_msgs.evidence["visible"] = np.stack(
[np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1
)
run_bp, _, get_beliefs = graph.BP(fg.bp_state, 100)

# %%
# Run inference and decode
msgs = fg.run_bp(100, 0.5, init_msgs)
map_states = fg.decode_map_states(msgs)
# Run inference and decode using vmap
n_samples = 16
bp_arrays = jax.vmap(run_bp, in_axes=0, out_axes=0)(
evidence_updates={
"hidden": np.stack(
[
np.zeros((n_samples,) + bh.shape),
bh + np.random.logistic(size=(n_samples,) + bh.shape),
],
axis=-1,
),
"visible": np.stack(
[
np.zeros((n_samples,) + bv.shape),
bv + np.random.logistic(size=(n_samples,) + bv.shape),
],
axis=-1,
),
}
)
StannisZhou marked this conversation as resolved.
Show resolved Hide resolved
map_states = graph.decode_map_states(
jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
)

# %%
# Visualize decodings
img = np.zeros(bv.shape)
for ii in range(nv):
img[ii] = map_states[("visible", ii)]
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
for ii in range(16):
ax[np.unravel_index(ii, (4, 4))].imshow(
map_states["visible"][ii].copy().reshape((28, 28))
)
ax[np.unravel_index(ii, (4, 4))].axis("off")

img = img.reshape((28, 28))
plt.imshow(img)
fig.tight_layout()
Loading