Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

First pass for speeding up graph and evidence operations #84

Merged
merged 8 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/sanity_check_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
additional_keys = tuple(extra_row_keys + extra_col_keys)
additional_keys_group = groups.GenericVariableGroup(3, additional_keys)
additional_keys_group = groups.VariableDict(3, additional_keys)

# Combine these two VariableGroups into one CompositeVariableGroup
composite_grid_group = groups.CompositeVariableGroup(
Expand Down
84 changes: 41 additions & 43 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from pgmax.bp import infer
from pgmax.fg import fg_utils, groups, nodes
from pgmax.utils import cached_property


@dataclass
Expand Down Expand Up @@ -85,6 +86,9 @@ def __post_init__(self):
self._total_factor_num_states: int = 0
self._factor_group_to_starts: Dict[groups.FactorGroup, int] = {}

def __hash__(self) -> int:
return hash(tuple(self._factor_groups))

def add_factor(
self,
*args,
Expand Down Expand Up @@ -185,7 +189,7 @@ def get_factor(self, key: Any) -> Tuple[nodes.EnumerationFactor, int]:

return factor, start

@property
@cached_property
def wiring(self) -> nodes.EnumerationWiring:
"""Function to compile wiring for belief propagation.
Expand All @@ -201,7 +205,7 @@ def wiring(self) -> nodes.EnumerationWiring:
wiring = fg_utils.concatenate_enumeration_wirings(wirings)
return wiring

@property
@cached_property
def factor_configs_log_potentials(self) -> np.ndarray:
"""Function to compile potential array for belief propagation..
Expand All @@ -218,10 +222,14 @@ def factor_configs_log_potentials(self) -> np.ndarray:
]
)

@property
@cached_property
def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
"""List of individual factors in the factor graph"""
return sum([factor_group.factors for factor_group in self._factor_groups], ())
factors = []
for factor_group in self._factor_groups:
factors.extend(factor_group.factors)

return tuple(factors)

def get_init_msgs(self) -> Messages:
"""Function to initialize messages.
Expand Down Expand Up @@ -327,12 +335,11 @@ def decode_map_states(self, msgs: Messages) -> Dict[Tuple[Any, ...], int]:
evidence = jax.device_put(msgs.evidence.value)
final_var_states = evidence.at[var_states_for_edges].add(msgs.ftov.value)
var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {}
final_var_states_np = np.array(final_var_states)
for var_key in self._variable_group.keys:
var = self._variable_group[var_key]
start_index = self._vars_to_starts[var]
var_key_to_map_dict[var_key] = np.argmax(
final_var_states_np[start_index : start_index + var.num_states]
var_key_to_map_dict[var_key] = int(
jnp.argmax(final_var_states[start_index : start_index + var.num_states])
StannisZhou marked this conversation as resolved.
Show resolved Hide resolved
)
return var_key_to_map_dict

Expand Down Expand Up @@ -523,31 +530,30 @@ class Evidence:

factor_graph: FactorGraph
default_mode: Optional[str] = None
init_value: Optional[Union[np.ndarray, jnp.ndarray]] = None
value: Optional[Union[np.ndarray, jnp.ndarray]] = None

def __post_init__(self):
self._evidence_updates: Dict[
nodes.Variable, Union[np.ndarray, jnp.ndarray]
] = {}
if self.default_mode is not None and self.init_value is not None:
raise ValueError("Should specify only one of default_mode and init_value.")
if self.default_mode is not None and self.value is not None:
raise ValueError("Should specify only one of default_mode and value.")

if self.default_mode is None and self.init_value is None:
if self.default_mode is None and self.value is None:
self.default_mode = "zeros"

if self.init_value is None and self.default_mode not in ("zeros", "random"):
if self.value is None and self.default_mode not in ("zeros", "random"):
raise ValueError(
f"Unsupported default evidence mode {self.default_mode}. "
"Supported default modes are zeros or random"
)

if self.init_value is None:
if self.value is None:
if self.default_mode == "zeros":
self.init_value = jnp.zeros(self.factor_graph.num_var_states)
self.value = jnp.zeros(self.factor_graph.num_var_states)
else:
self.init_value = jax.device_put(
self.value = jax.device_put(
np.random.gumbel(size=(self.factor_graph.num_var_states,))
)
else:
self.value = jax.device_put(self.value)

def __getitem__(self, key: Any) -> jnp.ndarray:
"""Function to query evidence for a variable
Expand All @@ -559,14 +565,8 @@ def __getitem__(self, key: Any) -> jnp.ndarray:
evidence for the queried variable
"""
variable = self.factor_graph._variable_group[key]
if self.factor_graph._variable_group[key] in self._evidence_updates:
evidence = jax.device_put(self._evidence_updates[variable])
else:
start = self.factor_graph._vars_to_starts[variable]
evidence = jax.device_put(self.init_value)[
start : start + variable.num_states
]

start = self.factor_graph._vars_to_starts[variable]
evidence = jax.device_put(self.value)[start : start + variable.num_states]
return evidence

def __setitem__(
Expand All @@ -591,7 +591,7 @@ def __setitem__(
Note that the size of the final dimension should be the same as
variable_group.variable_size. if key indexes a particular variable, then this array
must be of the same size as variable.num_states
- a dictionary: if key indexes a GenericVariableGroup, then evidence_values
- a dictionary: if key indexes a VariableDict, then evidence_values
must be a dictionary mapping keys of variable_group to np.ndarrays of evidence values.
Note that each np.ndarray in the dictionary values must have the same size as
variable_group.variable_size.
Expand All @@ -604,26 +604,24 @@ def __setitem__(
self.factor_graph._variable_group.variable_group_container[key]
)

self._evidence_updates.update(variable_group.get_vars_to_evidence(evidence))
for var, evidence_val in variable_group.get_vars_to_evidence(
evidence
).items():
start_index = self.factor_graph._vars_to_starts[var]
self.value = (
jax.device_put(self.value)
.at[start_index : start_index + evidence_val.shape[0]]
.set(evidence_val)
)
else:
self._evidence_updates[self.factor_graph._variable_group[key]] = evidence

@property
def value(self) -> jnp.ndarray:
"""Function to generate evidence array
Returns:
Array of shape (num_var_states,) representing the flattened evidence for each variable
"""
evidence = jax.device_put(self.init_value)
for var, evidence_val in self._evidence_updates.items():
var = self.factor_graph._variable_group[key]
start_index = self.factor_graph._vars_to_starts[var]
evidence = evidence.at[start_index : start_index + var.num_states].set(
evidence_val
self.value = (
jax.device_put(self.value)
.at[start_index : start_index + var.num_states]
.set(evidence)
)

return evidence


@dataclass
class Messages:
Expand Down
49 changes: 33 additions & 16 deletions pgmax/fg/groups.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import collections
import itertools
import typing
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Dict,
Hashable,
List,
Mapping,
Optional,
OrderedDict,
Sequence,
Tuple,
Union,
)

import numpy as np

Expand Down Expand Up @@ -71,7 +83,7 @@ def __getitem__(self, key):
else:
return vars_list[0]

def _get_keys_to_vars(self) -> Dict[Any, nodes.Variable]:
def _get_keys_to_vars(self) -> OrderedDict[Any, nodes.Variable]:
"""Function that generates a dictionary mapping keys to variables.
Returns:
Expand Down Expand Up @@ -187,13 +199,13 @@ def __getitem__(self, key):
else:
return vars_list[0]

def _get_keys_to_vars(self) -> Dict[Hashable, nodes.Variable]:
def _get_keys_to_vars(self) -> OrderedDict[Hashable, nodes.Variable]:
"""Function that generates a dictionary mapping keys to variables.
Returns:
a dictionary mapping all possible keys to different variables.
"""
keys_to_vars: Dict[Hashable, nodes.Variable] = {}
keys_to_vars: OrderedDict[Hashable, nodes.Variable] = collections.OrderedDict()
for container_key in self.container_keys:
for variable_group_key in self.variable_group_container[container_key].keys:
if isinstance(variable_group_key, tuple):
Expand Down Expand Up @@ -257,13 +269,17 @@ class NDVariableArray(VariableGroup):
variable_size: int
shape: Tuple[int, ...]

def _get_keys_to_vars(self) -> Dict[Union[int, Tuple[int, ...]], nodes.Variable]:
def _get_keys_to_vars(
self,
) -> OrderedDict[Union[int, Tuple[int, ...]], nodes.Variable]:
"""Function that generates a dictionary mapping keys to variables.
Returns:
a dictionary mapping all possible keys to different variables.
"""
keys_to_vars: Dict[Union[int, Tuple[int, ...]], nodes.Variable] = {}
keys_to_vars: OrderedDict[
Union[int, Tuple[int, ...]], nodes.Variable
] = collections.OrderedDict()
for key in itertools.product(*[list(range(k)) for k in self.shape]):
if len(key) == 1:
keys_to_vars[key[0]] = nodes.Variable(self.variable_size)
Expand Down Expand Up @@ -294,34 +310,35 @@ def get_vars_to_evidence(
f"Got {evidence.shape}."
)

vars_to_evidence = {
self._keys_to_vars[key]: evidence[key] for key in self._keys_to_vars
}
vars_to_evidence = {self._keys_to_vars[self.keys[0]]: evidence.ravel()}
return vars_to_evidence


@dataclass(frozen=True, eq=False)
class GenericVariableGroup(VariableGroup):
"""A generic variable group that contains a set of variables of the same size
class VariableDict(VariableGroup):
"""A variable dictionary that contains a set of variables of the same size
Args:
variable_size: The size of the variables in this variable group
key_tuple: A tuple of all keys in this variable group
variable_names: A tuple of all names of the variables in this variable group
"""

variable_size: int
key_tuple: Tuple[Any, ...]
variable_names: Tuple[Any, ...]

def _get_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]:
def _get_keys_to_vars(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]:
"""Function that generates a dictionary mapping keys to variables.
Returns:
a dictionary mapping all possible keys to different variables.
"""
keys_to_vars: Dict[Tuple[Any, ...], nodes.Variable] = {}
for key in self.key_tuple:
keys_to_vars: OrderedDict[
Tuple[Any, ...], nodes.Variable
] = collections.OrderedDict()
for key in self.variable_names:
keys_to_vars[key] = nodes.Variable(self.variable_size)

return keys_to_vars

def get_vars_to_evidence(
Expand Down
4 changes: 2 additions & 2 deletions tests/fg/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def test_onevar_graph():
v_group = groups.GenericVariableGroup(15, (0,))
v_group = groups.VariableDict(15, (0,))
fg = graph.FactorGraph(v_group)
assert fg._variable_group[0].num_states == 15
with pytest.raises(ValueError) as verror:
Expand All @@ -19,7 +19,7 @@ def test_onevar_graph():

assert "Unsupported default message mode" in str(verror.value)
with pytest.raises(ValueError) as verror:
graph.Evidence(factor_graph=fg, default_mode="zeros", init_value=np.zeros(1))
graph.Evidence(factor_graph=fg, default_mode="zeros", value=np.zeros(1))

assert "Should specify only" in str(verror.value)
with pytest.raises(ValueError) as verror:
Expand Down
12 changes: 6 additions & 6 deletions tests/fg/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@


def test_vargroup_list_idx():
v_group = groups.GenericVariableGroup(15, tuple([0, 1, 2]))
v_group = groups.VariableDict(15, tuple([0, 1, 2]))
assert v_group[[0, 1, 2]][0].num_states == 15


def test_composite_vargroup_valueerror():
v_group1 = groups.GenericVariableGroup(15, tuple([0, 1, 2]))
v_group2 = groups.GenericVariableGroup(15, tuple([0, 1, 2]))
v_group1 = groups.VariableDict(15, tuple([0, 1, 2]))
v_group2 = groups.VariableDict(15, tuple([0, 1, 2]))
comp_var_group = groups.CompositeVariableGroup(tuple([v_group1, v_group2]))
with pytest.raises(ValueError) as verror:
comp_var_group[tuple([0])]
assert "The key needs" in str(verror.value)


def test_composite_vargroup_evidence():
v_group1 = groups.GenericVariableGroup(3, tuple([0, 1, 2]))
v_group1 = groups.VariableDict(3, tuple([0, 1, 2]))
v_group1.container_keys
v_group2 = groups.GenericVariableGroup(3, tuple([0, 1, 2]))
v_group2 = groups.VariableDict(3, tuple([0, 1, 2]))
comp_var_group = groups.CompositeVariableGroup(tuple([v_group1, v_group2]))
vars_to_evidence = comp_var_group.get_vars_to_evidence(
[{0: np.zeros(3)}, {0: np.zeros(3)}]
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_pairwisefacgroup_errors():


def test_generic_evidence_errors():
v_group = groups.GenericVariableGroup(3, tuple([0]))
v_group = groups.VariableDict(3, tuple([0]))
with pytest.raises(ValueError) as verror:
v_group.get_vars_to_evidence({1: np.zeros((1, 1))})
assert "The evidence is referring" in str(verror.value)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pgmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
additional_keys = tuple(extra_row_keys + extra_col_keys)
additional_keys_group = groups.GenericVariableGroup(3, additional_keys)
additional_keys_group = groups.VariableDict(3, additional_keys)

# Combine these two VariableGroups into one CompositeVariableGroup
composite_grid_group = groups.CompositeVariableGroup(
Expand Down