Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
Example notebook with PMAP sampling of RBMs trained on MNIST digits (#80
Browse files Browse the repository at this point in the history
)

* Special indexing for 1D variable array

* RBM notebook

* Clean up RBM notebook; include data

* Get coverage back to 100%

* Add unit test for running examples

* Fix imports

* Address comments; disable RBM test for speed

* Make notebooks work again
  • Loading branch information
StannisZhou authored Oct 17, 2021
1 parent 174f926 commit 5bad82c
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 4 deletions.
Binary file added examples/example_data/rbm_mnist.npz
Binary file not shown.
71 changes: 71 additions & 0 deletions examples/rbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.4
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# %%
# %matplotlib inline
import itertools

import matplotlib.pyplot as plt
import numpy as np

from pgmax.fg import graph, groups

# %%
# Load parameters
params = np.load("example_data/rbm_mnist.npz")
bv = params["bv"]
bh = params["bh"]
W = params["W"]
nv = bv.shape[0]
nh = bh.shape[0]

# %%
# Build factor graph
visible_variables = groups.NDVariableArray(variable_size=2, shape=(nv,))
hidden_variables = groups.NDVariableArray(variable_size=2, shape=(nh,))
fg = graph.FactorGraph(
variables=dict(visible=visible_variables, hidden=hidden_variables),
)
for ii in range(nh):
for jj in range(nv):
fg.add_factor(
[("hidden", ii), ("visible", jj)],
np.array(list(itertools.product(np.arange(2), repeat=2))),
np.array([0, 0, 0, W[ii, jj]]),
)

# %%
# Set evidence
init_msgs = fg.get_init_msgs()
init_msgs.evidence["hidden"] = np.stack(
[np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1
)
init_msgs.evidence["visible"] = np.stack(
[np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1
)

# %%
# Run inference and decode
msgs = fg.run_bp(100, 0.5, init_msgs)
map_states = fg.decode_map_states(msgs)

# %%
# Visualize decodings
img = np.zeros(bv.shape)
for ii in range(nv):
img[ii] = map_states[("visible", ii)]

img = img.reshape((28, 28))
plt.imshow(img)
10 changes: 7 additions & 3 deletions pgmax/fg/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,19 @@ class NDVariableArray(VariableGroup):
variable_size: int
shape: Tuple[int, ...]

def _get_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]:
def _get_keys_to_vars(self) -> Dict[Union[int, Tuple[int, ...]], nodes.Variable]:
"""Function that generates a dictionary mapping keys to variables.
Returns:
a dictionary mapping all possible keys to different variables.
"""
keys_to_vars: Dict[Tuple[int, ...], nodes.Variable] = {}
keys_to_vars: Dict[Union[int, Tuple[int, ...]], nodes.Variable] = {}
for key in itertools.product(*[list(range(k)) for k in self.shape]):
keys_to_vars[key] = nodes.Variable(self.variable_size)
if len(key) == 1:
keys_to_vars[key[0]] = nodes.Variable(self.variable_size)
else:
keys_to_vars[key] = nodes.Variable(self.variable_size)

return keys_to_vars

def get_vars_to_evidence(
Expand Down
7 changes: 6 additions & 1 deletion tests/fg/test_groups.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from pgmax.fg import groups
from pgmax.fg import groups, nodes


def test_vargroup_list_idx():
Expand Down Expand Up @@ -33,6 +33,11 @@ def test_composite_vargroup_evidence():
assert (arr == np.zeros(3, dtype=float)).all()


def test_1dvararray_indexing():
v_group = groups.NDVariableArray(2, (1,))
assert isinstance(v_group[0], nodes.Variable)


def test_ndvararray_evidence_error():
v_group = groups.NDVariableArray(3, (2, 2))
with pytest.raises(ValueError) as verror:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import subprocess
import sys

import pytest

EXAMPLES_DIR = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples")
)
EXAMPLES = [
# "heretic_example.py",
"ising_model.py",
# "rbm.py",
"sanity_check_example.py",
]


@pytest.mark.parametrize("example", EXAMPLES)
def test_example(example):
print(f"Running:\npython examples/{example}")
filename = os.path.join(EXAMPLES_DIR, example)
subprocess.check_call([sys.executable, filename])

0 comments on commit 5bad82c

Please sign in to comment.