Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
Moving to a functional interface (#88)
Browse files Browse the repository at this point in the history
* Compile wiring with individual factors

* Use ordered dict for keys to factors

* Rename keys to variables; Simplify

* Store variables to factors for factor graph

* Simplify messages manipulation

* Shorten name

* Don't run ci for regular push

* Rough outline for log potentials

* Log potentials manipulation

* Change Messages to BPState

* Allow log potentials for individual factors

* Make BPState independent of factor graph

* New classes for functional interface

* Get rid of default modes

* Functional updates functions

* Functional BP interface

* Remove old implementation

* Use functions for setitem

* Implement decode map states

* Make Ising model example run again

* Updated ising model notebook

* Make RBM example run again

* Implement flatten/unflatten for variable groups

* Use flatten in evidence updates

* flatten/unflatten for factor groups

* Use flatten in log potentials updates

* Simplify decode map states

* Fix flatten/unflatten

* Get rid of copy in unflatten

* Update notebooks

* Fix all notebooks

* Add examples for batching and gradients

* Separate out decode_map_states

* Make test_pgmax pass

* New test groups

* Fix test groups

* New test nodes

* Pass test graph

* Full coverage of graph

* Support default log_potential_matrix for pairwise factor groups

* Full coverage

* Docstrings

* Docstrings

* Separate add factor functions to clarify

* Update examples

* Fix tests

* Change key to name

* More renaming

* Try to avoid memory leaks

* functools.partial

* Small fixes

* Add raises

* Fix mypy error

* Raises

* Comments

* Fix tests
  • Loading branch information
StannisZhou authored Nov 1, 2021
1 parent 57afe77 commit e4f7208
Show file tree
Hide file tree
Showing 13 changed files with 1,992 additions and 1,033 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: continuous-integration

on:
push:
branches:
- '*'
pull_request:
branches:
- master
Expand Down
54 changes: 21 additions & 33 deletions examples/heretic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
# ---

# %%
# %matplotlib inline
# Standard Package Imports
from dataclasses import replace
from timeit import default_timer as timer
from typing import Any, List, Tuple

import jax
import jax.numpy as jnp

# %%
# %matplotlib inline
# Standard Package Imports
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -116,18 +115,17 @@ def binary_connected_variables(
W_pot = W_orig.swapaxes(0, 1)
for k_row in range(3):
for k_col in range(3):
fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=binary_connected_variables(28, 28, k_row, k_col),
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
connected_variable_names=binary_connected_variables(28, 28, k_row, k_col),
log_potential_matrix=W_pot[:, :, k_row, k_col],
)


# %% [markdown]
# # Construct Initial Messages

# %%


def custom_flatten_ordering(Mdown, Mup):
flat_idx = 0
flat_Mdown = Mdown.flatten()
Expand Down Expand Up @@ -177,30 +175,27 @@ def custom_flatten_ordering(Mdown, Mup):

# %% tags=[]
# Run BP
init_msgs = fg.get_init_msgs()
init_msgs.ftov = graph.FToVMessages(
factor_graph=fg,
init_value=jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
bp_state = replace(
fg.bp_state,
ftov_msgs=graph.FToVMessages(
fg_state=fg.fg_state,
value=jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
),
),
)
init_msgs.evidence[0] = np.array(bXn_evidence)
init_msgs.evidence[1] = np.array(bHn_evidence)
bp_state.evidence[0] = np.array(bXn_evidence)
bp_state.evidence[1] = np.array(bHn_evidence)
run_bp, _, get_beliefs = graph.BP(bp_state, 500)
bp_start_time = timer()
# Assign evidence to pixel vars
final_msgs = fg.run_bp(
500,
0.5,
init_msgs=init_msgs,
)
bp_arrays = run_bp()
bp_end_time = timer()
print(f"time taken for bp {bp_end_time - bp_start_time}")

# Run inference and convert result to human-readable data structure
data_writeback_start_time = timer()
map_message_dict = fg.decode_map_states(
final_msgs,
)
map_states = graph.decode_map_states(get_beliefs(bp_arrays))
data_writeback_end_time = timer()
print(
f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}"
Expand Down Expand Up @@ -236,13 +231,6 @@ def plot_images(images):


# %%
img_arr = np.zeros((1, im_size[0], im_size[1]))

for row in range(im_size[0]):
for col in range(im_size[1]):
img_val = float(map_message_dict[0, row, col])
if img_val == 2.0:
img_val = 0.4
img_arr[0, row, col] = img_val * 1.0

img_arr = map_states[0][None].copy().astype(float)
img_arr[img_arr == 2.0] = 0.4
plot_images(img_arr)
75 changes: 53 additions & 22 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# %%
# %matplotlib inline
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -25,18 +27,18 @@

# %%
variables = groups.NDVariableArray(variable_size=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables, evidence_default_mode="random")
connected_var_keys = []
fg = graph.FactorGraph(variables=variables)
connected_variable_names = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
connected_var_keys.append([(ii, jj), (kk, jj)])
connected_var_keys.append([(ii, jj), (ii, ll)])
connected_variable_names.append([(ii, jj), (kk, jj)])
connected_variable_names.append([(ii, jj), (ii, ll)])

fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=connected_var_keys,
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
connected_variable_names=connected_variable_names,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)
Expand All @@ -45,43 +47,72 @@
# ### Run inference and visualize results

# %%
msgs = fg.run_bp(3000, 0.5)
map_states = fg.decode_map_states(msgs)
img = np.zeros((50, 50))
for key in map_states:
img[key] = map_states[key]
bp_state = fg.bp_state
run_bp, _, get_beliefs = graph.BP(bp_state, 3000)

# %%
bp_arrays = run_bp(
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)

# %%
img = graph.decode_map_states(get_beliefs(bp_arrays))
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)


# %% [markdown]
# ### Gradients and batching

# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = run_bp(
log_potentials_updates=log_potentials_updates, evidence_updates=evidence_updates
)
beliefs = get_beliefs(bp_arrays)
loss = -jnp.sum(beliefs)
return loss


batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))

# %%
batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})

# %%
grads = log_potentials_grads(
{"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)

# %% [markdown]
# ### Message and evidence manipulation

# %%
# Query evidence for variable (0, 0)
msgs.evidence[0, 0]
bp_state.evidence[0, 0]

# %%
# Set evidence for variable (0, 0)
msgs.evidence[0, 0] = np.array([1.0, 1.0])
msgs.evidence[0, 0]
bp_state.evidence[0, 0] = np.array([1.0, 1.0])
bp_state.evidence[0, 0]

# %%
# Set evidence for all variables using an array
evidence = np.random.randn(50, 50, 2)
msgs.evidence[:] = evidence
msgs.evidence[10, 10] == evidence[10, 10]
bp_state.evidence[None] = evidence
bp_state.evidence[10, 10] == evidence[10, 10]

# %%
# Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Uniformly spread expected belief at a variable to all connected factors
msgs.ftov[0, 0] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
bp_state.ftov_msgs[0, 0] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]
48 changes: 32 additions & 16 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# %matplotlib inline
import itertools

import jax
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -47,25 +48,40 @@
)

# %%
# Set evidence
init_msgs = fg.get_init_msgs()
init_msgs.evidence["hidden"] = np.stack(
[np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1
)
init_msgs.evidence["visible"] = np.stack(
[np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1
)
run_bp, _, get_beliefs = graph.BP(fg.bp_state, 100)

# %%
# Run inference and decode
msgs = fg.run_bp(100, 0.5, init_msgs)
map_states = fg.decode_map_states(msgs)
# Run inference and decode using vmap
n_samples = 16
bp_arrays = jax.vmap(run_bp, in_axes=0, out_axes=0)(
evidence_updates={
"hidden": np.stack(
[
np.zeros((n_samples,) + bh.shape),
bh + np.random.logistic(size=(n_samples,) + bh.shape),
],
axis=-1,
),
"visible": np.stack(
[
np.zeros((n_samples,) + bv.shape),
bv + np.random.logistic(size=(n_samples,) + bv.shape),
],
axis=-1,
),
}
)
map_states = graph.decode_map_states(
jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
)

# %%
# Visualize decodings
img = np.zeros(bv.shape)
for ii in range(nv):
img[ii] = map_states[("visible", ii)]
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
for ii in range(16):
ax[np.unravel_index(ii, (4, 4))].imshow(
map_states["visible"][ii].copy().reshape((28, 28))
)
ax[np.unravel_index(ii, (4, 4))].axis("off")

img = img.reshape((28, 28))
plt.imshow(img)
fig.tight_layout()
Loading

0 comments on commit e4f7208

Please sign in to comment.