diff --git a/examples/example_data/rbm_mnist.npz b/examples/example_data/rbm_mnist.npz new file mode 100644 index 00000000..cbf80f1a Binary files /dev/null and b/examples/example_data/rbm_mnist.npz differ diff --git a/examples/rbm.py b/examples/rbm.py new file mode 100644 index 00000000..0a0ce86d --- /dev/null +++ b/examples/rbm.py @@ -0,0 +1,71 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.4 +# kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% +# %matplotlib inline +import itertools + +import matplotlib.pyplot as plt +import numpy as np + +from pgmax.fg import graph, groups + +# %% +# Load parameters +params = np.load("example_data/rbm_mnist.npz") +bv = params["bv"] +bh = params["bh"] +W = params["W"] +nv = bv.shape[0] +nh = bh.shape[0] + +# %% +# Build factor graph +visible_variables = groups.NDVariableArray(variable_size=2, shape=(nv,)) +hidden_variables = groups.NDVariableArray(variable_size=2, shape=(nh,)) +fg = graph.FactorGraph( + variables=dict(visible=visible_variables, hidden=hidden_variables), +) +for ii in range(nh): + for jj in range(nv): + fg.add_factor( + [("hidden", ii), ("visible", jj)], + np.array(list(itertools.product(np.arange(2), repeat=2))), + np.array([0, 0, 0, W[ii, jj]]), + ) + +# %% +# Set evidence +init_msgs = fg.get_init_msgs() +init_msgs.evidence["hidden"] = np.stack( + [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 +) +init_msgs.evidence["visible"] = np.stack( + [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 +) + +# %% +# Run inference and decode +msgs = fg.run_bp(100, 0.5, init_msgs) +map_states = fg.decode_map_states(msgs) + +# %% +# Visualize decodings +img = np.zeros(bv.shape) +for ii in range(nv): + img[ii] = map_states[("visible", ii)] + +img = img.reshape((28, 28)) +plt.imshow(img) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index d90835e2..9f9dcd51 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -257,15 +257,19 @@ class NDVariableArray(VariableGroup): variable_size: int shape: Tuple[int, ...] - def _get_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: + def _get_keys_to_vars(self) -> Dict[Union[int, Tuple[int, ...]], nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: a dictionary mapping all possible keys to different variables. """ - keys_to_vars: Dict[Tuple[int, ...], nodes.Variable] = {} + keys_to_vars: Dict[Union[int, Tuple[int, ...]], nodes.Variable] = {} for key in itertools.product(*[list(range(k)) for k in self.shape]): - keys_to_vars[key] = nodes.Variable(self.variable_size) + if len(key) == 1: + keys_to_vars[key[0]] = nodes.Variable(self.variable_size) + else: + keys_to_vars[key] = nodes.Variable(self.variable_size) + return keys_to_vars def get_vars_to_evidence( diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 63e07d28..dba379d0 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from pgmax.fg import groups +from pgmax.fg import groups, nodes def test_vargroup_list_idx(): @@ -33,6 +33,11 @@ def test_composite_vargroup_evidence(): assert (arr == np.zeros(3, dtype=float)).all() +def test_1dvararray_indexing(): + v_group = groups.NDVariableArray(2, (1,)) + assert isinstance(v_group[0], nodes.Variable) + + def test_ndvararray_evidence_error(): v_group = groups.NDVariableArray(3, (2, 2)) with pytest.raises(ValueError) as verror: diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 00000000..3c88c4c4 --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,22 @@ +import os +import subprocess +import sys + +import pytest + +EXAMPLES_DIR = os.path.abspath( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples") +) +EXAMPLES = [ + # "heretic_example.py", + "ising_model.py", + # "rbm.py", + "sanity_check_example.py", +] + + +@pytest.mark.parametrize("example", EXAMPLES) +def test_example(example): + print(f"Running:\npython examples/{example}") + filename = os.path.join(EXAMPLES_DIR, example) + subprocess.check_call([sys.executable, filename])