From d332170953a706bce289fb06d7f3d682144555d4 Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 20:38:40 -0700 Subject: [PATCH 01/56] Compile wiring with individual factors --- pgmax/fg/graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 7e25ab23..755b0182 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -199,8 +199,7 @@ def wiring(self) -> nodes.EnumerationWiring: compiled wiring from each individual factor """ wirings = [ - factor_group.compile_wiring(self._vars_to_starts) - for factor_group in self._factor_groups + factor.compile_wiring(self._vars_to_starts) for factor in self.factors ] wiring = fg_utils.concatenate_enumeration_wirings(wirings) return wiring From d7ac26140da510e939aa22fe1b4ac56eb99f5fe5 Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 21:51:04 -0700 Subject: [PATCH 02/56] Use ordered dict for keys to factors --- pgmax/fg/groups.py | 95 ++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index dfb88fd2..22171f9b 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -446,7 +446,7 @@ def factor_group_log_potentials(self) -> np.ndarray: [factor.factor_configs_log_potentials for factor in self.factors] ) - def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: """Function that generates a dictionary mapping keys to factors. Returns: @@ -502,7 +502,7 @@ class EnumerationFactorGroup(FactorGroup): factor_configs: np.ndarray factor_configs_log_potentials: Optional[np.ndarray] = None - def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: """Function that generates a dictionary mapping keys to factors. Returns: @@ -516,24 +516,35 @@ def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: factor_configs_log_potentials = self.factor_configs_log_potentials if isinstance(self.connected_var_keys, Sequence): - keys_to_factors: Dict[Hashable, nodes.EnumerationFactor] = { - frozenset(self.connected_var_keys[ii]): nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), - self.factor_configs, - factor_configs_log_potentials, - ) - for ii in range(len(self.connected_var_keys)) - } + keys_to_factors: OrderedDict[ + Hashable, nodes.EnumerationFactor + ] = collections.OrderedDict( + [ + ( + frozenset(self.connected_var_keys[ii]), + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), + self.factor_configs, + factor_configs_log_potentials, + ), + ) + for ii in range(len(self.connected_var_keys)) + ] + ) else: - keys_to_factors = { - key: nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[key]]), - self.factor_configs, - factor_configs_log_potentials, - ) - for key in self.connected_var_keys - } - + keys_to_factors = collections.OrderedDict( + [ + ( + key, + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[key]]), + self.factor_configs, + factor_configs_log_potentials, + ), + ) + for key in self.connected_var_keys + ] + ) return keys_to_factors @@ -562,7 +573,7 @@ class PairwiseFactorGroup(FactorGroup): ] log_potential_matrix: np.ndarray - def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: """Function that generates a dictionary mapping keys to factors. Returns: @@ -607,22 +618,34 @@ def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: factor_configs[:, 0], factor_configs[:, 1] ] if isinstance(self.connected_var_keys, Sequence): - keys_to_factors: Dict[Hashable, nodes.EnumerationFactor] = { - frozenset(self.connected_var_keys[ii]): nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), - factor_configs, - factor_configs_log_potentials, - ) - for ii in range(len(self.connected_var_keys)) - } + keys_to_factors: OrderedDict[ + Hashable, nodes.EnumerationFactor + ] = collections.OrderedDict( + [ + ( + frozenset(self.connected_var_keys[ii]), + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), + factor_configs, + factor_configs_log_potentials, + ), + ) + for ii in range(len(self.connected_var_keys)) + ] + ) else: - keys_to_factors = { - key: nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[key]]), - factor_configs, - factor_configs_log_potentials, - ) - for key in self.connected_var_keys - } + keys_to_factors = collections.OrderedDict( + [ + ( + key, + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[key]]), + factor_configs, + factor_configs_log_potentials, + ), + ) + for key in self.connected_var_keys + ] + ) return keys_to_factors From 82a4c951eb24a53d494acce666d49f5ba3dba44d Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 22:13:53 -0700 Subject: [PATCH 03/56] Rename keys to variables; Simplify --- pgmax/fg/groups.py | 162 ++++++++++++++++++--------------------------- 1 file changed, 64 insertions(+), 98 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 22171f9b..d86825c6 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -5,7 +5,9 @@ from types import MappingProxyType from typing import ( Any, + Collection, Dict, + FrozenSet, Hashable, List, Mapping, @@ -386,36 +388,45 @@ class FactorGroup: all the variables that are connected to this FactorGroup Attributes: - _keys_to_factors: maps factor keys to the corresponding factors + _variables_to_factors: maps set of involved variables to the corresponding factors Raises: ValueError: if connected_var_keys is an empty list """ variable_group: Union[CompositeVariableGroup, VariableGroup] - _keys_to_factors: Mapping[Hashable, nodes.EnumerationFactor] = field(init=False) + _variables_to_factors: Mapping[FrozenSet, nodes.EnumerationFactor] = field( + init=False + ) def __post_init__(self) -> None: """Initializes a tuple of all the factors contained within this FactorGroup.""" object.__setattr__( - self, "_keys_to_factors", MappingProxyType(self._get_keys_to_factors()) + self, + "_variables_to_factors", + MappingProxyType(self._get_variables_to_factors()), ) - def __getitem__(self, key: Hashable) -> nodes.EnumerationFactor: + def __getitem__( + self, + variables: Union[Sequence, Collection], + ) -> nodes.EnumerationFactor: """Function to query individual factors in the factor group Args: - key: a key used to query an individual factor in the factor group + variables: a set of variables, used to query an individual factor in the factor group + involving this set of variables Returns: A queried individual factor """ - if key not in self.keys: + variables = frozenset(variables) + if variables not in self._variables_to_factors: raise ValueError( - f"The queried factor {key} is not present in the factor group" + f"The queried factor {variables} is not present in the factor group" ) - return self._keys_to_factors[key] + return self._variables_to_factors[variables] def compile_wiring( self, vars_to_starts: Mapping[nodes.Variable, int] @@ -446,7 +457,9 @@ def factor_group_log_potentials(self) -> np.ndarray: [factor.factor_configs_log_potentials for factor in self.factors] ) - def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: + def _get_variables_to_factors( + self, + ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: """Function that generates a dictionary mapping keys to factors. Returns: @@ -456,15 +469,10 @@ def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor] "Please subclass the VariableGroup class and override this method" ) - @cached_property - def keys(self) -> Tuple[Hashable, ...]: - """Returns all keys in the factor group.""" - return tuple(self._keys_to_factors.keys()) - @cached_property def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: """Returns all factors in the factor group.""" - return tuple(self._keys_to_factors.values()) + return tuple(self._variables_to_factors.values()) @cached_property def factor_num_states(self) -> np.ndarray: @@ -495,18 +503,17 @@ class EnumerationFactorGroup(FactorGroup): initialized. """ - connected_var_keys: Union[ - Sequence[List[Tuple[Hashable, ...]]], - Mapping[Any, List[Tuple[Hashable, ...]]], - ] + connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] factor_configs: np.ndarray factor_configs_log_potentials: Optional[np.ndarray] = None - def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: - """Function that generates a dictionary mapping keys to factors. + def _get_variables_to_factors( + self, + ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: + """Function that generates a dictionary mapping set of involved variables to factors. Returns: - a dictionary mapping all possible keys to different factors. + a dictionary mapping all possible set of involved variables to different factors. """ if self.factor_configs_log_potentials is None: factor_configs_log_potentials = np.zeros( @@ -515,37 +522,20 @@ def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor] else: factor_configs_log_potentials = self.factor_configs_log_potentials - if isinstance(self.connected_var_keys, Sequence): - keys_to_factors: OrderedDict[ - Hashable, nodes.EnumerationFactor - ] = collections.OrderedDict( - [ - ( - frozenset(self.connected_var_keys[ii]), - nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), - self.factor_configs, - factor_configs_log_potentials, - ), - ) - for ii in range(len(self.connected_var_keys)) - ] - ) - else: - keys_to_factors = collections.OrderedDict( - [ - ( - key, - nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[key]]), - self.factor_configs, - factor_configs_log_potentials, - ), - ) - for key in self.connected_var_keys - ] - ) - return keys_to_factors + variables_to_factors = collections.OrderedDict( + [ + ( + frozenset(self.connected_var_keys[ii]), + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), + self.factor_configs, + factor_configs_log_potentials, + ), + ) + for ii in range(len(self.connected_var_keys)) + ] + ) + return variables_to_factors @dataclass(frozen=True, eq=False) @@ -567,29 +557,23 @@ class PairwiseFactorGroup(FactorGroup): VariableGroup) whose keys are present in each sub-list from self.connected_var_keys. """ - connected_var_keys: Union[ - Sequence[List[Tuple[Hashable, ...]]], - Mapping[Any, List[Tuple[Hashable, ...]]], - ] + connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] log_potential_matrix: np.ndarray - def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor]: - """Function that generates a dictionary mapping keys to factors. + def _get_variables_to_factors( + self, + ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: + """Function that generates a dictionary mapping set of involved variables to factors. Returns: - a dictionary mapping all possible keys to different factors. + a dictionary mapping all possible set of involved variables to different factors. Raises: ValueError: if every sub-list within self.connected_var_keys has len != 2, or if the shape of the log_potential_matrix is not the same as the variable sizes for each variable referenced in each sub-list of self.connected_var_keys """ - if isinstance(self.connected_var_keys, Sequence): - connected_var_keys = self.connected_var_keys - else: - connected_var_keys = tuple(self.connected_var_keys.values()) - - for fac_list in connected_var_keys: + for fac_list in self.connected_var_keys: if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" @@ -617,35 +601,17 @@ def _get_keys_to_factors(self) -> OrderedDict[Hashable, nodes.EnumerationFactor] factor_configs_log_potentials = self.log_potential_matrix[ factor_configs[:, 0], factor_configs[:, 1] ] - if isinstance(self.connected_var_keys, Sequence): - keys_to_factors: OrderedDict[ - Hashable, nodes.EnumerationFactor - ] = collections.OrderedDict( - [ - ( - frozenset(self.connected_var_keys[ii]), - nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), - factor_configs, - factor_configs_log_potentials, - ), - ) - for ii in range(len(self.connected_var_keys)) - ] - ) - else: - keys_to_factors = collections.OrderedDict( - [ - ( - key, - nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[key]]), - factor_configs, - factor_configs_log_potentials, - ), - ) - for key in self.connected_var_keys - ] - ) - - return keys_to_factors + variables_to_factors = collections.OrderedDict( + [ + ( + frozenset(self.connected_var_keys[ii]), + nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), + factor_configs, + factor_configs_log_potentials, + ), + ) + for ii in range(len(self.connected_var_keys)) + ] + ) + return variables_to_factors From efde82dec5ca6a8f08e54117edcc3fbb91a12f79 Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 22:27:05 -0700 Subject: [PATCH 04/56] Store variables to factors for factor graph --- pgmax/fg/graph.py | 49 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 755b0182..1eb95372 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -2,10 +2,22 @@ from __future__ import annotations +import collections import typing from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + FrozenSet, + Hashable, + Mapping, + Optional, + OrderedDict, + Sequence, + Tuple, + Union, +) import jax import jax.numpy as jnp @@ -35,7 +47,6 @@ class FactorGraph: Attributes: _variable_group: VariableGroup. contains all involved VariableGroups - _factor_groups: List of added factor groups num_var_states: int. represents the sum of all variable states of all variables in the FactorGraph _vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int @@ -81,13 +92,17 @@ def __post_init__(self): } ) self.num_var_states = vars_num_states_cumsum[-1] - self._factor_groups: List[groups.FactorGroup] = [] self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} self._total_factor_num_states: int = 0 - self._factor_group_to_starts: Dict[groups.FactorGroup, int] = {} + self._factor_group_to_starts: OrderedDict[ + groups.FactorGroup, int + ] = collections.OrderedDict() + self._variables_to_factors: OrderedDict[ + FrozenSet, nodes.EnumerationFactor + ] = collections.OrderedDict() def __hash__(self) -> int: - return hash(tuple(self._factor_groups)) + return hash(self.factor_groups) def add_factor( self, @@ -139,7 +154,15 @@ def add_factor( self._variable_group, **kwargs ) - self._factor_groups.append(factor_group) + duplicate_factors = set(factor_group._variables_to_factors).intersection( + set(self._variables_to_factors) + ) + if len(duplicate_factors) > 0: + raise ValueError( + f"Factors involving variables {duplicate_factors} already exist. Please merge the corresponding factors." + ) + + self._variables_to_factors.update(factor_group._variables_to_factors) self._factor_group_to_starts[factor_group] = self._total_factor_num_states self._total_factor_num_states += np.sum(factor_group.factor_num_states) if name is not None: @@ -150,7 +173,6 @@ def get_factor(self, key: Any) -> Tuple[nodes.EnumerationFactor, int]: Args: key: the key for the factor. - The queried factor must be part of an named factor group. Returns: A tuple of length 2, containing the queried factor and its corresponding @@ -217,18 +239,19 @@ def factor_configs_log_potentials(self) -> np.ndarray: return np.concatenate( [ factor_group.factor_group_log_potentials - for factor_group in self._factor_groups + for factor_group in self.factor_groups ] ) @cached_property def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - """List of individual factors in the factor graph""" - factors = [] - for factor_group in self._factor_groups: - factors.extend(factor_group.factors) + """Tuple of individual factors in the factor graph""" + return tuple(self._variables_to_factors.values()) - return tuple(factors) + @property + def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: + """Tuple of factor groups in the factor graph""" + return tuple(self._factor_group_to_starts.keys()) def get_init_msgs(self) -> Messages: """Function to initialize messages. From 2d5a1a48dc15d815ac8e6a09d804591fadb5d04a Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 23:00:39 -0700 Subject: [PATCH 05/56] Simplify messages manipulation --- pgmax/fg/graph.py | 156 +++++++++++++++++----------------------------- 1 file changed, 57 insertions(+), 99 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 1eb95372..45e865db 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -100,6 +100,9 @@ def __post_init__(self): self._variables_to_factors: OrderedDict[ FrozenSet, nodes.EnumerationFactor ] = collections.OrderedDict() + self._factor_to_starts: OrderedDict[ + nodes.EnumerationFactor, int + ] = collections.OrderedDict() def __hash__(self) -> int: return hash(self.factor_groups) @@ -154,62 +157,26 @@ def add_factor( self._variable_group, **kwargs ) - duplicate_factors = set(factor_group._variables_to_factors).intersection( - set(self._variables_to_factors) - ) - if len(duplicate_factors) > 0: - raise ValueError( - f"Factors involving variables {duplicate_factors} already exist. Please merge the corresponding factors." - ) - - self._variables_to_factors.update(factor_group._variables_to_factors) self._factor_group_to_starts[factor_group] = self._total_factor_num_states - self._total_factor_num_states += np.sum(factor_group.factor_num_states) - if name is not None: - self._named_factor_groups[name] = factor_group - - def get_factor(self, key: Any) -> Tuple[nodes.EnumerationFactor, int]: - """Function to get an individual factor and start index - - Args: - key: the key for the factor. - - Returns: - A tuple of length 2, containing the queried factor and its corresponding - start index in the flat message array. - """ - if key in self._named_factor_groups: - if len(self._named_factor_groups[key].factors) != 1: - raise ValueError( - f"Invalid factor key {key}. " - "Please provide a key for an individual factor, " - "not a factor group" - ) - - factor_group = self._named_factor_groups[key] - factor = factor_group.factors[0] - start = self._factor_group_to_starts[factor_group] - else: - if not ( - isinstance(key, tuple) - and len(key) == 2 - and key[0] in self._named_factor_groups - ): + factor_num_states_cumsum = np.insert( + factor_group.factor_num_states.cumsum(), 0, 0 + ) + for vv, variables in enumerate(factor_group._variables_to_factors): + if variables in self._variables_to_factors: raise ValueError( - f"Invalid factor key {key}. " - "Please provide a key either for an individual named factor, " - "or a tuple of length 2 specifying name of the factor group " - "and index of individual factors" + f"A factor involving variables {variables} already exists. Please merge the corresponding factors." ) - factor_group = self._named_factor_groups[key[0]] - factor = factor_group[key[1]] - - start = self._factor_group_to_starts[factor_group] + np.sum( - factor_group.factor_num_states[: factor_group.factors.index(factor)] + factor = factor_group._variables_to_factors[variables] + self._variables_to_factors[variables] = factor + self._factor_to_starts[factor] = ( + self._factor_group_to_starts[factor_group] + + factor_num_states_cumsum[vv] ) - return factor, start + self._total_factor_num_states += factor_num_states_cumsum[-1] + if name is not None: + self._named_factor_groups[name] = factor_group @cached_property def wiring(self) -> nodes.EnumerationWiring: @@ -337,7 +304,7 @@ def message_passing_step(msgs, _): msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) return Messages( - ftov=FToVMessages(factor_graph=self, init_value=msgs_after_bp), + ftov=FToVMessages(factor_graph=self, value=msgs_after_bp), evidence=init_msgs.evidence, ) @@ -374,8 +341,8 @@ class FToVMessages: factor_graph: associated factor graph default_mode: default mode for initializing ftov messages. Allowed values include "zeros" and "random" - If init_value is None, defaults to "zeros" - init_value: Optionally specify initial value for ftov messages + If value is None, defaults to "zeros" + value: Optionally specify initial value for ftov messages Attributes: _message_updates: Dict[int, jnp.ndarray]. A dictionary containing @@ -385,28 +352,34 @@ class FToVMessages: factor_graph: FactorGraph default_mode: Optional[str] = None - init_value: Optional[Union[np.ndarray, jnp.ndarray]] = None + value: Optional[Union[np.ndarray, jnp.ndarray]] = None def __post_init__(self): - self._message_updates: Dict[int, jnp.ndarray] = {} - if self.default_mode is not None and self.init_value is not None: - raise ValueError("Should specify only one of default_mode and init_value.") + if self.default_mode is not None and self.value is not None: + raise ValueError("Should specify only one of default_mode and value.") - if self.default_mode is None and self.init_value is None: + if self.default_mode is None and self.value is None: self.default_mode = "zeros" - if self.init_value is None: + if self.value is None: if self.default_mode == "zeros": - self.init_value = np.zeros(self.factor_graph._total_factor_num_states) + self.value = jnp.zeros(self.factor_graph._total_factor_num_states) elif self.default_mode == "random": - self.init_value = np.random.gumbel( - size=(self.factor_graph._total_factor_num_states,) + self.value = jax.device_put( + np.random.gumbel(size=(self.factor_graph._total_factor_num_states,)) ) else: raise ValueError( f"Unsupported default message mode {self.default_mode}. " "Supported default modes are zeros or random" ) + else: + value = jax.device_put(self.value) + if not value.shape == (self.factor_graph._total_factor_num_states,): + raise ValueError( + f"Expected messages shape {(self.factor_graph._total_factor_num_states,)}. " + f"Got {value.shape}." + ) def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: """Function to query messages from a factor to a variable @@ -429,13 +402,12 @@ def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: "keys to get the messages from a named factor to a variable" ) - factor, start = self.factor_graph.get_factor(keys[0]) - if start in self._message_updates: - msgs = self._message_updates[start] - else: - variable = self.factor_graph._variable_group[keys[1]] - msgs = jax.device_put(self.init_value)[start : start + variable.num_states] - + factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] + variable = self.factor_graph._variable_group[keys[1]] + start = self.factor_graph._factor_to_starts[factor] + np.sum( + factor.edges_num_states[: factor.variables.index(variable)] + ) + msgs = jax.device_put(self.value)[start : start + variable.num_states] return jax.device_put(msgs) @typing.overload @@ -475,8 +447,11 @@ def __setitem__(self, keys, data) -> None: and len(keys) == 2 and keys[1] in self.factor_graph._variable_group.keys ): - factor, start = self.factor_graph.get_factor(keys[0]) + factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] variable = self.factor_graph._variable_group[keys[1]] + start = self.factor_graph._factor_to_starts[factor] + np.sum( + factor.edges_num_states[: factor.variables.index(variable)] + ) if data.shape != (variable.num_states,): raise ValueError( f"Given message shape {data.shape} does not match expected " @@ -484,10 +459,11 @@ def __setitem__(self, keys, data) -> None: f"to variable {keys[1]}." ) - self._message_updates[ - start - + np.sum(factor.edges_num_states[: factor.variables.index(variable)]) - ] = data + self.value = ( + jax.device_put(self.value) + .at[start : start + variable.num_states] + .set(data) + ) elif keys in self.factor_graph._variable_group.keys: variable = self.factor_graph._variable_group[keys] if data.shape != (variable.num_states,): @@ -501,7 +477,11 @@ def __setitem__(self, keys, data) -> None: == self.factor_graph._vars_to_starts[variable] )[0] for start in starts: - self._message_updates[start] = data / starts.shape[0] + self.value = ( + jax.device_put(self.value) + .at[start : start + variable.num_states] + .st(data / starts.shape[0]) + ) else: raise ValueError( "Invalid keys for setting messages. " @@ -511,28 +491,6 @@ def __setitem__(self, keys, data) -> None: "beliefs at a variable" ) - @property - def value(self) -> jnp.ndarray: - """Functin to get the current flat message array - - Returns: - The flat message array after initializing (according to default_mode - or init_value) and applying all message updates. - """ - init_value = jax.device_put(self.init_value) - if not init_value.shape == (self.factor_graph._total_factor_num_states,): - raise ValueError( - f"Expected messages shape {(self.factor_graph._total_factor_num_states,)}. " - f"Got {init_value.shape}." - ) - - msgs = init_value - for start in self._message_updates: - data = self._message_updates[start] - msgs = msgs.at[start : start + data.shape[0]].set(data) - - return msgs - @dataclass class Evidence: @@ -542,8 +500,8 @@ class Evidence: factor_graph: associated factor graph default_mode: default mode for initializing evidence. Allowed values include "zeros" and "random" - If init_value is None, defaults to "zeros" - init_value: Optionally specify initial value for evidence + If value is None, defaults to "zeros" + value: Optionally specify initial value for evidence Attributes: _evidence_updates: Dict[nodes.Variable, np.ndarray]. maps every variable to an np.ndarray From fe2f14ca988ff842cf176fc1c680bfbc24fb092b Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 23:03:37 -0700 Subject: [PATCH 06/56] Shorten name --- pgmax/bp/infer.py | 6 +++--- pgmax/fg/graph.py | 8 +++----- pgmax/fg/groups.py | 24 ++++++++++-------------- pgmax/fg/nodes.py | 12 ++++++------ tests/fg/test_nodes.py | 20 ++++++++++---------- tests/test_pgmax.py | 6 ++---- 6 files changed, 34 insertions(+), 42 deletions(-) diff --git a/pgmax/bp/infer.py b/pgmax/bp/infer.py index e003b46a..f32e04ea 100644 --- a/pgmax/bp/infer.py +++ b/pgmax/bp/infer.py @@ -38,7 +38,7 @@ def pass_var_to_fac_messages( def pass_fac_to_var_messages( vtof_msgs: jnp.ndarray, factor_configs_edge_states: jnp.ndarray, - factor_configs_log_potentials: jnp.ndarray, + log_potentials: jnp.ndarray, num_val_configs: int, ) -> jnp.ndarray: @@ -55,7 +55,7 @@ def pass_fac_to_var_messages( factor_configs_edge_states[ii] contains a pair of global factor_config and edge_state indices factor_configs_edge_states[ii, 0] contains the global factor config index factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index - factor_configs_log_potentials: Array of shape (num_val_configs, ). An entry at index i is the log potential + log_potentials: Array of shape (num_val_configs, ). An entry at index i is the log potential function value for the configuration with global factor config index i. num_val_configs: the total number of valid configurations for factors in the factor graph. @@ -66,7 +66,7 @@ def pass_fac_to_var_messages( jnp.zeros(shape=(num_val_configs,)) .at[factor_configs_edge_states[..., 0]] .add(vtof_msgs[factor_configs_edge_states[..., 1]]) - ) + factor_configs_log_potentials + ) + log_potentials ftov_msgs = ( jnp.full(shape=(vtof_msgs.shape[0],), fill_value=NEG_INF) .at[factor_configs_edge_states[..., 1]] diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 45e865db..7a2c9b98 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -194,7 +194,7 @@ def wiring(self) -> nodes.EnumerationWiring: return wiring @cached_property - def factor_configs_log_potentials(self) -> np.ndarray: + def log_potentials(self) -> np.ndarray: """Function to compile potential array for belief propagation.. If potential array has already beeen compiled, do nothing. @@ -263,9 +263,7 @@ def run_bp( msgs = jax.device_put(init_msgs.ftov.value) evidence = jax.device_put(init_msgs.evidence.value) wiring = jax.device_put(self.wiring) - factor_configs_log_potentials = jax.device_put( - self.factor_configs_log_potentials - ) + log_potentials = jax.device_put(self.log_potentials) max_msg_size = int(jnp.max(wiring.edges_num_states)) # Normalize the messages to ensure the maximum value is 0. @@ -286,7 +284,7 @@ def message_passing_step(msgs, _): ftov_msgs = infer.pass_fac_to_var_messages( vtof_msgs, wiring.factor_configs_edge_states, - factor_configs_log_potentials, + log_potentials, num_val_configs, ) # Use the results of message passing to perform damping and diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index d86825c6..fc471de9 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -453,9 +453,7 @@ def factor_group_log_potentials(self) -> np.ndarray: a jnp array representing the log of the potential function for the factor group """ - return np.concatenate( - [factor.factor_configs_log_potentials for factor in self.factors] - ) + return np.concatenate([factor.log_potentials for factor in self.factors]) def _get_variables_to_factors( self, @@ -489,7 +487,7 @@ class EnumerationFactorGroup(FactorGroup): All factors in the group are assumed to have the same set of valid configurations and the same potential function. Note that the log potential function is assumed to be - uniform 0 unless the inheriting class includes a factor_configs_log_potentials argument. + uniform 0 unless the inheriting class includes a log_potentials argument. Args: connected_var_keys: A list of list of tuples, where each innermost tuple contains a @@ -497,7 +495,7 @@ class EnumerationFactorGroup(FactorGroup): neighboring a particular factor to be added. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations - factor_configs_log_potentials: Optional array of shape (num_val_configs,). + log_potentials: Optional array of shape (num_val_configs,). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized. @@ -505,7 +503,7 @@ class EnumerationFactorGroup(FactorGroup): connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] factor_configs: np.ndarray - factor_configs_log_potentials: Optional[np.ndarray] = None + log_potentials: Optional[np.ndarray] = None def _get_variables_to_factors( self, @@ -515,12 +513,10 @@ def _get_variables_to_factors( Returns: a dictionary mapping all possible set of involved variables to different factors. """ - if self.factor_configs_log_potentials is None: - factor_configs_log_potentials = np.zeros( - self.factor_configs.shape[0], dtype=float - ) + if self.log_potentials is None: + log_potentials = np.zeros(self.factor_configs.shape[0], dtype=float) else: - factor_configs_log_potentials = self.factor_configs_log_potentials + log_potentials = self.log_potentials variables_to_factors = collections.OrderedDict( [ @@ -529,7 +525,7 @@ def _get_variables_to_factors( nodes.EnumerationFactor( tuple(self.variable_group[self.connected_var_keys[ii]]), self.factor_configs, - factor_configs_log_potentials, + log_potentials, ), ) for ii in range(len(self.connected_var_keys)) @@ -598,7 +594,7 @@ def _get_variables_to_factors( np.arange(self.log_potential_matrix.shape[1]), ) ).T.reshape((-1, 2)) - factor_configs_log_potentials = self.log_potential_matrix[ + log_potentials = self.log_potential_matrix[ factor_configs[:, 0], factor_configs[:, 1] ] variables_to_factors = collections.OrderedDict( @@ -608,7 +604,7 @@ def _get_variables_to_factors( nodes.EnumerationFactor( tuple(self.variable_group[self.connected_var_keys[ii]]), factor_configs, - factor_configs_log_potentials, + log_potentials, ), ) for ii in range(len(self.connected_var_keys)) diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 2f77395f..7256c098 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -65,7 +65,7 @@ class EnumerationFactor: variables: List of involved variables configs: Array of shape (num_val_configs, num_variables) An array containing an explicit enumeration of all valid configurations - factor_configs_log_potentials: Array of shape (num_val_configs,). An array containing + log_potentials: Array of shape (num_val_configs,). An array containing the log of the potential value for every possible configuration Raises: @@ -82,7 +82,7 @@ class EnumerationFactor: variables: Tuple[Variable, ...] configs: np.ndarray - factor_configs_log_potentials: np.ndarray + log_potentials: np.ndarray def __post_init__(self): self.configs.flags.writeable = False @@ -91,9 +91,9 @@ def __post_init__(self): f"Configurations should be integers. Got {self.configs.dtype}." ) - if not np.issubdtype(self.factor_configs_log_potentials.dtype, np.floating): + if not np.issubdtype(self.log_potentials.dtype, np.floating): raise ValueError( - f"Potential should be floats. Got {self.factor_configs_log_potentials.dtype}." + f"Potential should be floats. Got {self.log_potentials.dtype}." ) if len(self.variables) != self.configs.shape[1]: @@ -101,9 +101,9 @@ def __post_init__(self): f"Number of variables {len(self.variables)} doesn't match given configurations {self.configs.shape}" ) - if self.configs.shape[0] != self.factor_configs_log_potentials.shape[0]: + if self.configs.shape[0] != self.log_potentials.shape[0]: raise ValueError( - f"The potential array has {self.factor_configs_log_potentials.shape[0]} rows, which is not equal to the number of configurations ({self.configs.shape[0]})" + f"The potential array has {self.log_potentials.shape[0]} rows, which is not equal to the number of configurations ({self.configs.shape[0]})" ) vars_num_states = np.array([variable.num_states for variable in self.variables]) diff --git a/tests/fg/test_nodes.py b/tests/fg/test_nodes.py index a7a63f01..fb41be5e 100644 --- a/tests/fg/test_nodes.py +++ b/tests/fg/test_nodes.py @@ -7,10 +7,10 @@ def test_enumfactor_configints_error(): v = nodes.Variable(3) configs = np.array([[1.0]]) - factor_configs_log_potentials = np.array([1.0]) + log_potentials = np.array([1.0]) with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, factor_configs_log_potentials) + nodes.EnumerationFactor(tuple([v]), configs, log_potentials) assert "Configurations" in str(verror.value) @@ -18,10 +18,10 @@ def test_enumfactor_configints_error(): def test_enumfactor_potentials_error(): v = nodes.Variable(3) configs = np.array([[1]], dtype=int) - factor_configs_log_potentials = np.array([1], dtype=int) + log_potentials = np.array([1], dtype=int) with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, factor_configs_log_potentials) + nodes.EnumerationFactor(tuple([v]), configs, log_potentials) assert "Potential" in str(verror.value) @@ -30,10 +30,10 @@ def test_enumfactor_configsshape_error(): v1 = nodes.Variable(3) v2 = nodes.Variable(3) configs = np.array([[1]], dtype=int) - factor_configs_log_potentials = np.array([1.0]) + log_potentials = np.array([1.0]) with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v1, v2]), configs, factor_configs_log_potentials) + nodes.EnumerationFactor(tuple([v1, v2]), configs, log_potentials) assert "Number of variables" in str(verror.value) @@ -41,10 +41,10 @@ def test_enumfactor_configsshape_error(): def test_enumfactor_potentialshape_error(): v = nodes.Variable(3) configs = np.array([[1]], dtype=int) - factor_configs_log_potentials = np.array([1.0, 2.0]) + log_potentials = np.array([1.0, 2.0]) with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, factor_configs_log_potentials) + nodes.EnumerationFactor(tuple([v]), configs, log_potentials) assert "The potential array has" in str(verror.value) @@ -53,9 +53,9 @@ def test_enumfactor_configvarsize_error(): v1 = nodes.Variable(3) v2 = nodes.Variable(1) configs = np.array([[-1, 4]], dtype=int) - factor_configs_log_potentials = np.array([1.0]) + log_potentials = np.array([1.0]) with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v1, v2]), configs, factor_configs_log_potentials) + nodes.EnumerationFactor(tuple([v1, v2]), configs, log_potentials) assert "Invalid configurations for given variables" in str(verror.value) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index c023e029..eddf1496 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -303,7 +303,7 @@ def create_valid_suppression_config_arr(suppression_diameter): fg.add_factor( keys=curr_keys, factor_configs=valid_configs_non_supp, - factor_configs_log_potentials=np.zeros( + log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), name=(row, col), @@ -360,9 +360,7 @@ def create_valid_suppression_config_arr(suppression_diameter): factor_factory=groups.EnumerationFactorGroup, connected_var_keys=horz_suppression_keys, factor_configs=valid_configs_supp, - factor_configs_log_potentials=np.zeros( - valid_configs_supp.shape[0], dtype=float - ), + log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) # Run BP From 8fc6461f5b566e06738d121d5341a13eb69aa057 Mon Sep 17 00:00:00 2001 From: stannis Date: Fri, 22 Oct 2021 23:23:31 -0700 Subject: [PATCH 07/56] Don't run ci for regular push --- .github/workflows/ci.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 831b3373..3eb5c825 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,9 +1,6 @@ name: continuous-integration on: - push: - branches: - - '*' pull_request: branches: - master From d757d7ca67f91f7d08799f782676147ea440e288 Mon Sep 17 00:00:00 2001 From: stannis Date: Sat, 23 Oct 2021 00:06:02 -0700 Subject: [PATCH 08/56] Rough outline for log potentials --- pgmax/fg/graph.py | 118 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 103 insertions(+), 15 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 7a2c9b98..63559243 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -59,7 +59,7 @@ class FactorGraph: We only support setting messages from factors within explicitly named factor groups to connected variables. _total_factor_num_states: int. Current total number of edge states for the added factors. - _factor_group_to_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its + _factor_group_to_msgs_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its corresponding starting index in the flat message array. """ @@ -85,22 +85,31 @@ def __post_init__(self): 0, 0, ) + self.num_var_states = vars_num_states_cumsum[-1] self._vars_to_starts = MappingProxyType( { variable: vars_num_states_cumsum[vv] for vv, variable in enumerate(self._variable_group.variables) } ) - self.num_var_states = vars_num_states_cumsum[-1] self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} + self._variables_to_factors: OrderedDict[ + FrozenSet, nodes.EnumerationFactor + ] = collections.OrderedDict() + # For ftov messages self._total_factor_num_states: int = 0 - self._factor_group_to_starts: OrderedDict[ + self._factor_group_to_msgs_starts: OrderedDict[ groups.FactorGroup, int ] = collections.OrderedDict() - self._variables_to_factors: OrderedDict[ - FrozenSet, nodes.EnumerationFactor + self._factor_to_msgs_starts: OrderedDict[ + nodes.EnumerationFactor, int + ] = collections.OrderedDict() + # For log potentials + self._total_factor_num_configs: int = 0 + self._factor_group_to_potentials_starts: OrderedDict[ + groups.FactorGroup, int ] = collections.OrderedDict() - self._factor_to_starts: OrderedDict[ + self._factor_to_potentials_starts = OrderedDict[ nodes.EnumerationFactor, int ] = collections.OrderedDict() @@ -157,10 +166,14 @@ def add_factor( self._variable_group, **kwargs ) - self._factor_group_to_starts[factor_group] = self._total_factor_num_states + self._factor_group_to_msgs_starts[factor_group] = self._total_factor_num_states + self._factor_group_to_potentials_starts[ + factor_group + ] = self._total_factor_num_configs factor_num_states_cumsum = np.insert( factor_group.factor_num_states.cumsum(), 0, 0 ) + factor_group_num_configs = 0 for vv, variables in enumerate(factor_group._variables_to_factors): if variables in self._variables_to_factors: raise ValueError( @@ -169,12 +182,26 @@ def add_factor( factor = factor_group._variables_to_factors[variables] self._variables_to_factors[variables] = factor - self._factor_to_starts[factor] = ( - self._factor_group_to_starts[factor_group] + self._factor_to_msgs_starts[factor] = ( + self._factor_group_to_msgs_starts[factor_group] + factor_num_states_cumsum[vv] ) + self._factor_group_to_potentials_starts[factor] = ( + self._factor_group_to_potentials_starts[factor_group] + + vv * factor.log_potentials.shape[0] + ) + factor_group_num_configs += factor.log_potentials.shape[0] + + if ( + factor_group_num_configs + != factor_group.factor_group_log_potentials.shape[0] + ): + raise ValueError( + "Factors in a factor group should have the same number of valid configurations." + ) self._total_factor_num_states += factor_num_states_cumsum[-1] + self._total_factor_num_configs += factor_group_num_configs if name is not None: self._named_factor_groups[name] = factor_group @@ -218,7 +245,7 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: @property def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: """Tuple of factor groups in the factor graph""" - return tuple(self._factor_group_to_starts.keys()) + return tuple(self._factor_group_to_msgs_starts.keys()) def get_init_msgs(self) -> Messages: """Function to initialize messages. @@ -331,6 +358,66 @@ def decode_map_states(self, msgs: Messages) -> Dict[Tuple[Any, ...], int]: return var_key_to_map_dict +@dataclass +class LogPotentials: + factor_graph: FactorGraph + value: Optional[Union[np.ndarray, jnp.ndarray]] = None + + def __post_init__(self): + if self.value is None: + self.value = jax.device_put(self.factor_graph.log_potentials) + else: + if not self.value.shape == self.factor_graph.log_potentials.shape: + raise ValueError( + f"Expected log potentials shape shape {self.factor_graph.log_potentials.shape}. " + f"Got {self.value.shape}." + ) + + self.value = jax.device_put(self.value) + + def __getitem__(self, key: Any): + if key in self.factor_graph._named_factor_groups: + factor_group = self.factor_graph._named_factor_groups[key] + start = self.factor_graph._factor_group_to_potentials_starts[factor_group] + log_potentials = jax.device_put(self.value)[ + start : start + factor_group.factor_group_log_potentials.shape[0] + ] + elif frozenset(key) in self.factor_graph._variables_to_factors: + factor = self.factor_graph._variables_to_factors[frozenset(key)] + start = self.factor_graph._factor_to_potentials_starts[factor] + log_potentials = jax.device_put(self.value)[ + start : start + factor.log_potentials.shape[0] + ] + else: + raise ValueError("") + + return log_potentials + + def __setitem__( + self, + key: Any, + data: Union[np.ndarray, jnp.ndarray], + ): + if key in self.factor_graph._named_factor_groups: + factor_group = self.factor_graph._named_factor_groups[key] + start = self.factor_graph._factor_group_to_potentials_starts[factor_group] + self.value = ( + jax.device_put(self.value) + .at[start : start + factor_group.factor_group_log_potentials.shape[0]] + .set(data) + ) + elif frozenset(key) in self.factor_graph._variables_to_factors: + factor = self.factor_graph._variables_to_factors[frozenset(key)] + start = self.factor_graph._factor_to_potentials_starts[factor] + self.value = ( + jax.device_put(self.value) + .at[start : start + factor.log_potentials.shape[0]] + .set(data) + ) + else: + raise ValueError("") + + @dataclass class FToVMessages: """Class for storing and manipulating factor to variable messages. @@ -372,13 +459,14 @@ def __post_init__(self): "Supported default modes are zeros or random" ) else: - value = jax.device_put(self.value) - if not value.shape == (self.factor_graph._total_factor_num_states,): + if not self.value.shape == (self.factor_graph._total_factor_num_states,): raise ValueError( f"Expected messages shape {(self.factor_graph._total_factor_num_states,)}. " - f"Got {value.shape}." + f"Got {self.value.shape}." ) + self.value = jax.device_put(self.value) + def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: """Function to query messages from a factor to a variable @@ -402,7 +490,7 @@ def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] variable = self.factor_graph._variable_group[keys[1]] - start = self.factor_graph._factor_to_starts[factor] + np.sum( + start = self.factor_graph._factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) msgs = jax.device_put(self.value)[start : start + variable.num_states] @@ -447,7 +535,7 @@ def __setitem__(self, keys, data) -> None: ): factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] variable = self.factor_graph._variable_group[keys[1]] - start = self.factor_graph._factor_to_starts[factor] + np.sum( + start = self.factor_graph._factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) if data.shape != (variable.num_states,): From d3346057d532daa13d9a454f1978525992488e99 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 13:52:53 -0700 Subject: [PATCH 09/56] Log potentials manipulation --- pgmax/fg/graph.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 63559243..80dd4c17 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -400,6 +400,12 @@ def __setitem__( ): if key in self.factor_graph._named_factor_groups: factor_group = self.factor_graph._named_factor_groups[key] + if data.shape != factor_group.factor_group_log_potentials.shape: + raise ValueError( + f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " + f"for factor group {key}. Got {data.shape}." + ) + start = self.factor_graph._factor_group_to_potentials_starts[factor_group] self.value = ( jax.device_put(self.value) @@ -408,6 +414,12 @@ def __setitem__( ) elif frozenset(key) in self.factor_graph._variables_to_factors: factor = self.factor_graph._variables_to_factors[frozenset(key)] + if data.shape != factor.log_potentials.shape: + raise ValueError( + f"Expected log potentials shape {factor.log_potentials.shape} " + f"for factor {key}. Got {data.shape}." + ) + start = self.factor_graph._factor_to_potentials_starts[factor] self.value = ( jax.device_put(self.value) From 23583d95cfe6c28163428d70ca1c671c794162b4 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 13:57:32 -0700 Subject: [PATCH 10/56] Change Messages to BPState --- pgmax/fg/graph.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 80dd4c17..079725d3 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -247,13 +247,14 @@ def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: """Tuple of factor groups in the factor graph""" return tuple(self._factor_group_to_msgs_starts.keys()) - def get_init_msgs(self) -> Messages: + def get_init_msgs(self) -> BPState: """Function to initialize messages. Returns: Initialized messages """ - return Messages( + return BPState( + log_potentials=LogPotentials(factor_graph=self), ftov=FToVMessages( factor_graph=self, default_mode=self.messages_default_mode ), @@ -266,8 +267,8 @@ def run_bp( self, num_iters: int, damping_factor: float, - init_msgs: Optional[Messages] = None, - ) -> Messages: + init_msgs: Optional[BPState] = None, + ) -> BPState: """Function to perform belief propagation. Specifically, belief propagation is run for num_iters iterations and @@ -328,12 +329,13 @@ def message_passing_step(msgs, _): return msgs, None msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) - return Messages( + return BPState( + log_potentials=LogPotentials(factor_graph=self), ftov=FToVMessages(factor_graph=self, value=msgs_after_bp), evidence=init_msgs.evidence, ) - def decode_map_states(self, msgs: Messages) -> Dict[Tuple[Any, ...], int]: + def decode_map_states(self, msgs: BPState) -> Dict[Tuple[Any, ...], int]: """Function to computes the output of MAP inference on input messages. The final states are computed based on evidence obtained from the self.get_evidence @@ -702,13 +704,16 @@ def __setitem__( @dataclass -class Messages: - """Container class for factor to variable messages and evidence. +class BPState: + """Container class for belief propagation states, including log potentials, + ftov messages and evidence (unary log potentials). Args: + log_potentials: log potentials of the model ftov: factor to variable messages evidence: evidence """ + log_potentials: LogPotentials ftov: FToVMessages evidence: Evidence From 08273131b7633dcbb056fb096b3091b98252c4c7 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 14:43:28 -0700 Subject: [PATCH 11/56] Allow log potentials for individual factors --- pgmax/fg/graph.py | 2 +- pgmax/fg/groups.py | 70 ++++++++++++++++++++++++++++++++++------------ pgmax/fg/nodes.py | 12 ++++++-- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 079725d3..3f5f424b 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -371,7 +371,7 @@ def __post_init__(self): else: if not self.value.shape == self.factor_graph.log_potentials.shape: raise ValueError( - f"Expected log potentials shape shape {self.factor_graph.log_potentials.shape}. " + f"Expected log potentials shape {self.factor_graph.log_potentials.shape}. " f"Got {self.value.shape}." ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index fc471de9..d0cfa510 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -495,7 +495,7 @@ class EnumerationFactorGroup(FactorGroup): neighboring a particular factor to be added. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations - log_potentials: Optional array of shape (num_val_configs,). + log_potentials: Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized. @@ -513,10 +513,25 @@ def _get_variables_to_factors( Returns: a dictionary mapping all possible set of involved variables to different factors. """ + num_factors = len(self.connected_var_keys) + num_val_configs = self.factor_configs.shape[0] if self.log_potentials is None: - log_potentials = np.zeros(self.factor_configs.shape[0], dtype=float) + log_potentials = np.zeros((num_factors, num_val_configs), dtype=float) else: - log_potentials = self.log_potentials + if self.log_potentials.shape != ( + self.factor_configs.shape[0], + ) or self.log_potentials.shape != ( + num_factors, + self.factor_configs.shape[0], + ): + raise ValueError( + f"Expected log potentials shape: {(num_val_configs,)} or {(num_factors, num_val_configs)}. " + f"Got {self.log_potentials.shape}." + ) + + log_potentials = np.broadcast_to( + self.log_potentials, (num_factors, self.factor_configs.shape[0]) + ) variables_to_factors = collections.OrderedDict( [ @@ -525,7 +540,7 @@ def _get_variables_to_factors( nodes.EnumerationFactor( tuple(self.variable_group[self.connected_var_keys[ii]]), self.factor_configs, - log_potentials, + log_potentials[ii], ), ) for ii in range(len(self.connected_var_keys)) @@ -569,34 +584,51 @@ def _get_variables_to_factors( log_potential_matrix is not the same as the variable sizes for each variable referenced in each sub-list of self.connected_var_keys """ + if not ( + self.log_potential_matrix.ndim == 2 or self.log_potential_matrix.ndim == 3 + ): + raise ValueError( + "log_potential_matrix should be either a 2D array, specifying shared parameters for all " + "pairwise factors, or 3D array, specifying parameters for individual pairwise factors. " + f"Got a {self.log_potential_matrix.ndim}D log_potential_matrix array." + ) + + if self.log_potential_matrix.ndim == 3 and self.log_potential_matrix.shape[ + 0 + ] != len(self.connected_var_keys): + raise ValueError( + f"Expected log_potential_matrix for {len(self.connected_var_keys)} factors. " + f"Got log_potential_matrix for {self.log_potential_matrix.shape[0]} factors." + ) + for fac_list in self.connected_var_keys: if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" f" more or less than 2 variables ({fac_list})." ) + if not ( - self.log_potential_matrix.shape + self.log_potential_matrix.shape[-2:] == ( self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states, ) ): raise ValueError( - "self.log_potential_matrix must have shape" - + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " - + f"based on self.connected_var_keys. Instead, it has shape {self.log_potential_matrix.shape}" + f"The specified pairwise factor {fac_list} (with " + f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " + "configurations) does not match the specified log_potential_matrix " + "(with {self.log_potential_matrix.shape[-2:]} configurations)." ) - factor_configs = np.array( - np.meshgrid( - np.arange(self.log_potential_matrix.shape[0]), - np.arange(self.log_potential_matrix.shape[1]), - ) - ).T.reshape((-1, 2)) - log_potentials = self.log_potential_matrix[ - factor_configs[:, 0], factor_configs[:, 1] - ] + factor_configs = np.mgrid[ + : self.log_potential_matrix.shape[0], : self.log_potential_matrix.shape[1] + ].T.reshape((-1, 2)) + log_potential_matrix = np.broadcast_to( + self.log_potential_matrix, + (len(self.connected_var_keys),) + self.log_potential_matrix.shape[-2:], + ) variables_to_factors = collections.OrderedDict( [ ( @@ -604,7 +636,9 @@ def _get_variables_to_factors( nodes.EnumerationFactor( tuple(self.variable_group[self.connected_var_keys[ii]]), factor_configs, - log_potentials, + log_potential_matrix[ + ii, factor_configs[:, 0], factor_configs[:, 1] + ], ), ) for ii in range(len(self.connected_var_keys)) diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 7256c098..14b4c206 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -96,14 +96,22 @@ def __post_init__(self): f"Potential should be floats. Got {self.log_potentials.dtype}." ) + if self.configs.ndim != 2: + raise ValueError( + "configs should be a 2D array containing a list of valid configurations for " + f"EnumerationFactor. Got a configs array of shape {self.configs.shape}." + ) + if len(self.variables) != self.configs.shape[1]: raise ValueError( f"Number of variables {len(self.variables)} doesn't match given configurations {self.configs.shape}" ) - if self.configs.shape[0] != self.log_potentials.shape[0]: + if self.log_potentials.shape != (self.configs.shape[0],): raise ValueError( - f"The potential array has {self.log_potentials.shape[0]} rows, which is not equal to the number of configurations ({self.configs.shape[0]})" + f"Expected log potentials of shape {(self.configs.shape[0],)} for " + f"({self.configs.shape[0]}) valid configurations. Got log potentials of " + f"shape {self.log_potentials.shape}." ) vars_num_states = np.array([variable.num_states for variable in self.variables]) From a02915847f4280fb73491f8b17395e75da3d70e8 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 15:37:01 -0700 Subject: [PATCH 12/56] Make BPState independent of factor graph --- pgmax/fg/graph.py | 179 +++++++++++++++++++++++++--------------------- 1 file changed, 99 insertions(+), 80 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 3f5f424b..b442ba14 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -3,6 +3,7 @@ from __future__ import annotations import collections +import copy import typing from dataclasses import dataclass from types import MappingProxyType @@ -254,6 +255,7 @@ def get_init_msgs(self) -> BPState: Initialized messages """ return BPState( + wiring=self.wiring, log_potentials=LogPotentials(factor_graph=self), ftov=FToVMessages( factor_graph=self, default_mode=self.messages_default_mode @@ -330,6 +332,7 @@ def message_passing_step(msgs, _): msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) return BPState( + wiring=self.wiring, log_potentials=LogPotentials(factor_graph=self), ftov=FToVMessages(factor_graph=self, value=msgs_after_bp), evidence=init_msgs.evidence, @@ -360,33 +363,42 @@ def decode_map_states(self, msgs: BPState) -> Dict[Tuple[Any, ...], int]: return var_key_to_map_dict -@dataclass class LogPotentials: - factor_graph: FactorGraph - value: Optional[Union[np.ndarray, jnp.ndarray]] = None - - def __post_init__(self): - if self.value is None: - self.value = jax.device_put(self.factor_graph.log_potentials) + def __init__( + self, + factor_graph: FactorGraph, + value: Optional[Union[np.ndarray, jnp.ndarray]] = None, + ): + if value is None: + self.value = jax.device_put(factor_graph.log_potentials) else: - if not self.value.shape == self.factor_graph.log_potentials.shape: + if not value.shape == factor_graph.log_potentials.shape: raise ValueError( - f"Expected log potentials shape {self.factor_graph.log_potentials.shape}. " + f"Expected log potentials shape {factor_graph.log_potentials.shape}. " f"Got {self.value.shape}." ) - self.value = jax.device_put(self.value) + self.value = jax.device_put(value) + + self._named_factor_groups = copy.copy(factor_graph._named_factor_groups) + self._factor_group_to_potentials_starts = copy.copy( + factor_graph._factor_group_to_potentials_starts + ) + self._factor_to_potentials_starts = copy.copy( + factor_graph._factor_to_potentials_starts + ) + self._variables_to_factors = copy.copy(factor_graph._variables_to_factors) def __getitem__(self, key: Any): - if key in self.factor_graph._named_factor_groups: - factor_group = self.factor_graph._named_factor_groups[key] - start = self.factor_graph._factor_group_to_potentials_starts[factor_group] + if key in self._named_factor_groups: + factor_group = self._named_factor_groups[key] + start = self._factor_group_to_potentials_starts[factor_group] log_potentials = jax.device_put(self.value)[ start : start + factor_group.factor_group_log_potentials.shape[0] ] - elif frozenset(key) in self.factor_graph._variables_to_factors: - factor = self.factor_graph._variables_to_factors[frozenset(key)] - start = self.factor_graph._factor_to_potentials_starts[factor] + elif frozenset(key) in self._variables_to_factors: + factor = self._variables_to_factors[frozenset(key)] + start = self._factor_to_potentials_starts[factor] log_potentials = jax.device_put(self.value)[ start : start + factor.log_potentials.shape[0] ] @@ -400,29 +412,29 @@ def __setitem__( key: Any, data: Union[np.ndarray, jnp.ndarray], ): - if key in self.factor_graph._named_factor_groups: - factor_group = self.factor_graph._named_factor_groups[key] + if key in self._named_factor_groups: + factor_group = self._named_factor_groups[key] if data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " f"for factor group {key}. Got {data.shape}." ) - start = self.factor_graph._factor_group_to_potentials_starts[factor_group] + start = self._factor_group_to_potentials_starts[factor_group] self.value = ( jax.device_put(self.value) .at[start : start + factor_group.factor_group_log_potentials.shape[0]] .set(data) ) - elif frozenset(key) in self.factor_graph._variables_to_factors: - factor = self.factor_graph._variables_to_factors[frozenset(key)] + elif frozenset(key) in self._variables_to_factors: + factor = self._variables_to_factors[frozenset(key)] if data.shape != factor.log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor.log_potentials.shape} " f"for factor {key}. Got {data.shape}." ) - start = self.factor_graph._factor_to_potentials_starts[factor] + start = self._factor_to_potentials_starts[factor] self.value = ( jax.device_put(self.value) .at[start : start + factor.log_potentials.shape[0]] @@ -432,7 +444,6 @@ def __setitem__( raise ValueError("") -@dataclass class FToVMessages: """Class for storing and manipulating factor to variable messages. @@ -449,37 +460,44 @@ class FToVMessages: Maps starting indices to the message values to update with. """ - factor_graph: FactorGraph - default_mode: Optional[str] = None - value: Optional[Union[np.ndarray, jnp.ndarray]] = None - - def __post_init__(self): - if self.default_mode is not None and self.value is not None: + def __init__( + self, + factor_graph: FactorGraph, + default_mode: Optional[str] = None, + value: Optional[Union[np.ndarray, jnp.ndarray]] = None, + ): + if default_mode is not None and value is not None: raise ValueError("Should specify only one of default_mode and value.") - if self.default_mode is None and self.value is None: - self.default_mode = "zeros" + if default_mode is None and value is None: + default_mode = "zeros" - if self.value is None: - if self.default_mode == "zeros": - self.value = jnp.zeros(self.factor_graph._total_factor_num_states) - elif self.default_mode == "random": + if value is None: + if default_mode == "zeros": + self.value = jnp.zeros(factor_graph._total_factor_num_states) + elif default_mode == "random": self.value = jax.device_put( - np.random.gumbel(size=(self.factor_graph._total_factor_num_states,)) + np.random.gumbel(size=(factor_graph._total_factor_num_states,)) ) else: raise ValueError( - f"Unsupported default message mode {self.default_mode}. " + f"Unsupported default message mode {default_mode}. " "Supported default modes are zeros or random" ) else: - if not self.value.shape == (self.factor_graph._total_factor_num_states,): + if not value.shape == (factor_graph._total_factor_num_states,): raise ValueError( - f"Expected messages shape {(self.factor_graph._total_factor_num_states,)}. " - f"Got {self.value.shape}." + f"Expected messages shape {(factor_graph._total_factor_num_states,)}. " + f"Got {value.shape}." ) - self.value = jax.device_put(self.value) + self.value = jax.device_put(value) + + self._variable_group = factor_graph._variable_group + self._vars_to_starts = factor_graph._vars_to_starts + self._variables_to_factors = copy.copy(factor_graph._variables_to_factors) + self._factor_to_msgs_starts = copy.copy(factor_graph._factor_to_msgs_starts) + self._var_states_for_edges = factor_graph.wiring.var_states_for_edges def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: """Function to query messages from a factor to a variable @@ -495,16 +513,16 @@ def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: if not ( isinstance(keys, tuple) and len(keys) == 2 - and keys[1] in self.factor_graph._variable_group.keys + and keys[1] in self._variable_group.keys ): raise ValueError( f"Invalid keys {keys}. Please specify a tuple of factor, variable " "keys to get the messages from a named factor to a variable" ) - factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] - variable = self.factor_graph._variable_group[keys[1]] - start = self.factor_graph._factor_to_msgs_starts[factor] + np.sum( + factor = self._variables_to_factors[frozenset(keys[0])] + variable = self._variable_group[keys[1]] + start = self._factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) msgs = jax.device_put(self.value)[start : start + variable.num_states] @@ -545,11 +563,11 @@ def __setitem__(self, keys, data) -> None: if ( isinstance(keys, tuple) and len(keys) == 2 - and keys[1] in self.factor_graph._variable_group.keys + and keys[1] in self._variable_group.keys ): - factor = self.factor_graph._variables_to_factors[frozenset(keys[0])] - variable = self.factor_graph._variable_group[keys[1]] - start = self.factor_graph._factor_to_msgs_starts[factor] + np.sum( + factor = self._variables_to_factors[frozenset(keys[0])] + variable = self._variable_group[keys[1]] + start = self._factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) if data.shape != (variable.num_states,): @@ -564,8 +582,8 @@ def __setitem__(self, keys, data) -> None: .at[start : start + variable.num_states] .set(data) ) - elif keys in self.factor_graph._variable_group.keys: - variable = self.factor_graph._variable_group[keys] + elif keys in self._variable_group.keys: + variable = self._variable_group[keys] if data.shape != (variable.num_states,): raise ValueError( f"Given belief shape {data.shape} does not match expected " @@ -573,8 +591,7 @@ def __setitem__(self, keys, data) -> None: ) starts = np.nonzero( - self.factor_graph.wiring.var_states_for_edges - == self.factor_graph._vars_to_starts[variable] + self._var_states_for_edges == self._vars_to_starts[variable] )[0] for start in starts: self.value = ( @@ -592,7 +609,6 @@ def __setitem__(self, keys, data) -> None: ) -@dataclass class Evidence: """Class for storing and manipulating evidence @@ -608,32 +624,36 @@ class Evidence: representing the evidence for that variable """ - factor_graph: FactorGraph - default_mode: Optional[str] = None - value: Optional[Union[np.ndarray, jnp.ndarray]] = None - - def __post_init__(self): - if self.default_mode is not None and self.value is not None: + def __init__( + self, + factor_graph: FactorGraph, + default_mode: Optional[str] = None, + value: Optional[Union[np.ndarray, jnp.ndarray]] = None, + ): + if default_mode is not None and value is not None: raise ValueError("Should specify only one of default_mode and value.") - if self.default_mode is None and self.value is None: - self.default_mode = "zeros" + if default_mode is None and value is None: + default_mode = "zeros" - if self.value is None and self.default_mode not in ("zeros", "random"): + if value is None and default_mode not in ("zeros", "random"): raise ValueError( - f"Unsupported default evidence mode {self.default_mode}. " + f"Unsupported default evidence mode {default_mode}. " "Supported default modes are zeros or random" ) - if self.value is None: - if self.default_mode == "zeros": - self.value = jnp.zeros(self.factor_graph.num_var_states) + if value is None: + if default_mode == "zeros": + self.value = jnp.zeros(factor_graph.num_var_states) else: self.value = jax.device_put( - np.random.gumbel(size=(self.factor_graph.num_var_states,)) + np.random.gumbel(size=(factor_graph.num_var_states,)) ) else: - self.value = jax.device_put(self.value) + self.value = jax.device_put(value) + + self._variable_group = factor_graph._variable_group + self._vars_to_starts = factor_graph._vars_to_starts def __getitem__(self, key: Any) -> jnp.ndarray: """Function to query evidence for a variable @@ -644,8 +664,8 @@ def __getitem__(self, key: Any) -> jnp.ndarray: Returns: evidence for the queried variable """ - variable = self.factor_graph._variable_group[key] - start = self.factor_graph._vars_to_starts[variable] + variable = self._variable_group[key] + start = self._vars_to_starts[variable] evidence = jax.device_put(self.value)[start : start + variable.num_states] return evidence @@ -658,7 +678,7 @@ def __setitem__( Args: key: tuple that represents the index into the VariableGroup - (self.factor_graph._variable_group) that is created when the FactorGraph is instantiated. Note that + (self._variable_group) that is created when the FactorGraph is instantiated. Note that this can be an index referring to an entire VariableGroup (in which case, the evidence is set for the entire VariableGroup at once), or to an individual Variable within the VariableGroup. @@ -676,26 +696,24 @@ def __setitem__( Note that each np.ndarray in the dictionary values must have the same size as variable_group.variable_size. """ - if key in self.factor_graph._variable_group.container_keys: + if key in self._variable_group.container_keys: if key == slice(None): - variable_group = self.factor_graph._variable_group + variable_group = self._variable_group else: - variable_group = ( - self.factor_graph._variable_group.variable_group_container[key] - ) + variable_group = self._variable_group.variable_group_container[key] for var, evidence_val in variable_group.get_vars_to_evidence( evidence ).items(): - start_index = self.factor_graph._vars_to_starts[var] + start_index = self._vars_to_starts[var] self.value = ( jax.device_put(self.value) .at[start_index : start_index + evidence_val.shape[0]] .set(evidence_val) ) else: - var = self.factor_graph._variable_group[key] - start_index = self.factor_graph._vars_to_starts[var] + var = self._variable_group[key] + start_index = self._vars_to_starts[var] self.value = ( jax.device_put(self.value) .at[start_index : start_index + var.num_states] @@ -714,6 +732,7 @@ class BPState: evidence: evidence """ + wiring: nodes.EnumerationWiring log_potentials: LogPotentials ftov: FToVMessages evidence: Evidence From f8265e3c0f6f645609cc0faf8e4cfce22c32783f Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 22:32:16 -0700 Subject: [PATCH 13/56] New classes for functional interface --- pgmax/fg/graph.py | 376 ++++++++++++++++++++++++++-------------------- 1 file changed, 210 insertions(+), 166 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index b442ba14..bb89a40a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -5,7 +5,7 @@ import collections import copy import typing -from dataclasses import dataclass +from dataclasses import dataclass, replace from types import MappingProxyType from typing import ( Any, @@ -69,8 +69,6 @@ class FactorGraph: Sequence[groups.VariableGroup], groups.VariableGroup, ] - messages_default_mode: str = "zeros" - evidence_default_mode: str = "zeros" def __post_init__(self): if isinstance(self.variables, groups.VariableGroup): @@ -248,28 +246,29 @@ def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: """Tuple of factor groups in the factor graph""" return tuple(self._factor_group_to_msgs_starts.keys()) - def get_init_msgs(self) -> BPState: - """Function to initialize messages. - - Returns: - Initialized messages - """ - return BPState( - wiring=self.wiring, - log_potentials=LogPotentials(factor_graph=self), - ftov=FToVMessages( - factor_graph=self, default_mode=self.messages_default_mode - ), - evidence=Evidence( - factor_graph=self, default_mode=self.evidence_default_mode + @cached_property + def fg_state(self) -> FactorGraphState: + return FactorGraphState( + variable_group=self._variable_group, + vars_to_starts=self._vars_to_starts, + num_var_states=self.num_var_states, + total_factor_num_states=self._total_factor_num_states, + variables_to_factors=copy.copy(self._variables_to_factors), + named_factor_groups=copy.copy(self._named_factor_groups), + factor_group_to_potentials_starts=copy.copy( + self._factor_group_to_potentials_starts ), + factor_to_potentials_starts=copy.copy(self._factor_to_potentials_starts), + factor_to_msgs_starts=copy.copy(self._factor_to_msgs_starts), + log_potentials=self.log_potentials, + wiring=self.wiring, ) def run_bp( self, num_iters: int, damping_factor: float, - init_msgs: Optional[BPState] = None, + bp_state: BPState, ) -> BPState: """Function to perform belief propagation. @@ -279,21 +278,17 @@ def run_bp( Args: num_iters: The number of iterations for which to perform message passing damping_factor: The damping factor to use for message updates between one timestep and the next - init_msgs: Initial messages to start the belief propagation. - If None, construct init_msgs by calling self.get_init_msgs() + bp_state: Initial messages to start the belief propagation. Returns: ftov messages after running BP for num_iters iterations """ # Retrieve the necessary data structures from the compiled self.wiring and # convert these to jax arrays. - if init_msgs is None: - init_msgs = self.get_init_msgs() - - msgs = jax.device_put(init_msgs.ftov.value) - evidence = jax.device_put(init_msgs.evidence.value) - wiring = jax.device_put(self.wiring) - log_potentials = jax.device_put(self.log_potentials) + msgs = jax.device_put(bp_state.ftov.value) + evidence = jax.device_put(bp_state.evidence.value) + wiring = jax.device_put(bp_state.fg_state.wiring) + log_potentials = jax.device_put(bp_state.log_potentials.value) max_msg_size = int(jnp.max(wiring.edges_num_states)) # Normalize the messages to ensure the maximum value is 0. @@ -331,74 +326,121 @@ def message_passing_step(msgs, _): return msgs, None msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) - return BPState( - wiring=self.wiring, - log_potentials=LogPotentials(factor_graph=self), - ftov=FToVMessages(factor_graph=self, value=msgs_after_bp), - evidence=init_msgs.evidence, + return replace( + bp_state, + ftov=FToVMessages(fg_state=bp_state.ftov.fg_state, value=msgs_after_bp), ) - def decode_map_states(self, msgs: BPState) -> Dict[Tuple[Any, ...], int]: + def decode_map_states(self, bp_state: BPState) -> Dict[Tuple[Any, ...], int]: """Function to computes the output of MAP inference on input messages. The final states are computed based on evidence obtained from the self.get_evidence method as well as the internal wiring. Args: - msgs: ftov messages for deciding MAP states + bp_state: ftov messages for deciding MAP states Returns: a dictionary mapping each variable key to the MAP states of the corresponding variable """ - var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges) - evidence = jax.device_put(msgs.evidence.value) - final_var_states = evidence.at[var_states_for_edges].add(msgs.ftov.value) + var_states_for_edges = jax.device_put( + bp_state.fg_state.wiring.var_states_for_edges + ) + evidence = jax.device_put(bp_state.evidence.value) + final_var_states = evidence.at[var_states_for_edges].add(bp_state.ftov.value) var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {} - for var_key in self._variable_group.keys: - var = self._variable_group[var_key] - start_index = self._vars_to_starts[var] + for var_key in bp_state.ftov.fg_state.variable_group.keys: + var = bp_state.ftov.fg_state.variable_group[var_key] + start_index = bp_state.ftov.fg_state.vars_to_starts[var] var_key_to_map_dict[var_key] = int( jnp.argmax(final_var_states[start_index : start_index + var.num_states]) ) + return var_key_to_map_dict +@dataclass(frozen=True, eq=False) +class FactorGraphState: + variable_group: groups.VariableGroup + vars_to_starts: Mapping[nodes.Variable, int] + num_var_states: int + total_factor_num_states: int + variables_to_factors: Mapping[FrozenSet, nodes.EnumerationFactor] + named_factor_groups: Mapping[Hashable, groups.FactorGroup] + factor_group_to_potentials_starts: Mapping[groups.FactorGroup, int] + factor_to_potentials_starts: Mapping[nodes.EnumerationFactor, int] + factor_to_msgs_starts: Mapping[nodes.EnumerationFactor, int] + log_potentials: np.ndarray + wiring: nodes.EnumerationWiring + + def __post_init__(self): + for field in self.__dataclass_fields__: + if isinstance(getattr(self, field), np.ndarray): + getattr(self, field).flags.writeable = False + + if isinstance(getattr(self, field), Mapping): + object.__setattr__(self, field, MappingProxyType(self.field)) + + +@dataclass(frozen=True, eq=False) +class BPState: + """Container class for belief propagation states, including log potentials, + ftov messages and evidence (unary log potentials). + + Args: + log_potentials: log potentials of the model + ftov: factor to variable messages + evidence: evidence + """ + + log_potentials: LogPotentials + ftov: FToVMessages + evidence: Evidence + + def __post_init__(self): + if (self.log_potentials.fg_state != self.ftov.fg_state) or ( + self.ftov.fg_state != self.evidence.fg_state + ): + raise ValueError( + "log_potentials, ftov and evidence should be derived from the same fg_state." + ) + + @property + def fg_state(self) -> FactorGraphState: + return self.log_potentials.fg_state + + +@dataclass(frozen=True, eq=False) class LogPotentials: - def __init__( - self, - factor_graph: FactorGraph, - value: Optional[Union[np.ndarray, jnp.ndarray]] = None, - ): - if value is None: - self.value = jax.device_put(factor_graph.log_potentials) + + fg_state: FactorGraphState + default_mode: Optional[str] = None + value: Optional[np.ndarray] = None + + def __post_init__(self): + if self.value is None: + object.__setattr__( + self, "value", jax.device_put(self.fg_state.log_potentials) + ) else: - if not value.shape == factor_graph.log_potentials.shape: + if not self.value.shape == self.fg_state.log_potentials.shape: raise ValueError( - f"Expected log potentials shape {factor_graph.log_potentials.shape}. " + f"Expected log potentials shape {self.fg_state.log_potentials.shape}. " f"Got {self.value.shape}." ) - self.value = jax.device_put(value) - - self._named_factor_groups = copy.copy(factor_graph._named_factor_groups) - self._factor_group_to_potentials_starts = copy.copy( - factor_graph._factor_group_to_potentials_starts - ) - self._factor_to_potentials_starts = copy.copy( - factor_graph._factor_to_potentials_starts - ) - self._variables_to_factors = copy.copy(factor_graph._variables_to_factors) + object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, key: Any): - if key in self._named_factor_groups: - factor_group = self._named_factor_groups[key] - start = self._factor_group_to_potentials_starts[factor_group] + if key in self.fg_state.named_factor_groups: + factor_group = self.fg_state.named_factor_groups[key] + start = self.fg_state.factor_group_to_potentials_starts[factor_group] log_potentials = jax.device_put(self.value)[ start : start + factor_group.factor_group_log_potentials.shape[0] ] - elif frozenset(key) in self._variables_to_factors: - factor = self._variables_to_factors[frozenset(key)] - start = self._factor_to_potentials_starts[factor] + elif frozenset(key) in self.fg_state.variables_to_factors: + factor = self.fg_state.variables_to_factors[frozenset(key)] + start = self.fg_state.factor_to_potentials_starts[factor] log_potentials = jax.device_put(self.value)[ start : start + factor.log_potentials.shape[0] ] @@ -412,38 +454,43 @@ def __setitem__( key: Any, data: Union[np.ndarray, jnp.ndarray], ): - if key in self._named_factor_groups: - factor_group = self._named_factor_groups[key] + if key in self.fg_state.named_factor_groups: + factor_group = self.fg_state.named_factor_groups[key] if data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " f"for factor group {key}. Got {data.shape}." ) - start = self._factor_group_to_potentials_starts[factor_group] - self.value = ( + start = self.fg_state.factor_group_to_potentials_starts[factor_group] + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start : start + factor_group.factor_group_log_potentials.shape[0]] - .set(data) + .set(data), ) - elif frozenset(key) in self._variables_to_factors: - factor = self._variables_to_factors[frozenset(key)] + elif frozenset(key) in self.fg_state.variables_to_factors: + factor = self.fg_state.variables_to_factors[frozenset(key)] if data.shape != factor.log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor.log_potentials.shape} " f"for factor {key}. Got {data.shape}." ) - start = self._factor_to_potentials_starts[factor] - self.value = ( + start = self.fg_state.factor_to_potentials_starts[factor] + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start : start + factor.log_potentials.shape[0]] - .set(data) + .set(data), ) else: raise ValueError("") +@dataclass(frozen=True, eq=False) class FToVMessages: """Class for storing and manipulating factor to variable messages. @@ -460,44 +507,43 @@ class FToVMessages: Maps starting indices to the message values to update with. """ - def __init__( - self, - factor_graph: FactorGraph, - default_mode: Optional[str] = None, - value: Optional[Union[np.ndarray, jnp.ndarray]] = None, - ): - if default_mode is not None and value is not None: + fg_state: FactorGraphState + default_mode: Optional[str] = None + value: Optional[Union[np.ndarray, jnp.ndarray]] = None + + def __post_init__(self): + if self.default_mode is not None and self.value is not None: raise ValueError("Should specify only one of default_mode and value.") - if default_mode is None and value is None: - default_mode = "zeros" + if self.default_mode is None and self.value is None: + object.__setattr__(self, "default_mode", "zeros") - if value is None: - if default_mode == "zeros": - self.value = jnp.zeros(factor_graph._total_factor_num_states) - elif default_mode == "random": - self.value = jax.device_put( - np.random.gumbel(size=(factor_graph._total_factor_num_states,)) + if self.value is None: + if self.default_mode == "zeros": + object.__setattr__( + self, "value", jnp.zeros(self.fg_state.total_factor_num_states) + ) + elif self.default_mode == "random": + object.__setattr__( + self, + "value", + jax.device_put( + np.random.gumbel(size=(self.fg_state.total_factor_num_states,)) + ), ) else: raise ValueError( - f"Unsupported default message mode {default_mode}. " + f"Unsupported default message mode {self.default_mode}. " "Supported default modes are zeros or random" ) else: - if not value.shape == (factor_graph._total_factor_num_states,): + if not self.value.shape == (self.fg_state.total_factor_num_states,): raise ValueError( - f"Expected messages shape {(factor_graph._total_factor_num_states,)}. " - f"Got {value.shape}." + f"Expected messages shape {(self.fg_state.total_factor_num_states,)}. " + f"Got {self.value.shape}." ) - self.value = jax.device_put(value) - - self._variable_group = factor_graph._variable_group - self._vars_to_starts = factor_graph._vars_to_starts - self._variables_to_factors = copy.copy(factor_graph._variables_to_factors) - self._factor_to_msgs_starts = copy.copy(factor_graph._factor_to_msgs_starts) - self._var_states_for_edges = factor_graph.wiring.var_states_for_edges + object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: """Function to query messages from a factor to a variable @@ -513,16 +559,16 @@ def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: if not ( isinstance(keys, tuple) and len(keys) == 2 - and keys[1] in self._variable_group.keys + and keys[1] in self.fg_state.variable_group.keys ): raise ValueError( f"Invalid keys {keys}. Please specify a tuple of factor, variable " "keys to get the messages from a named factor to a variable" ) - factor = self._variables_to_factors[frozenset(keys[0])] - variable = self._variable_group[keys[1]] - start = self._factor_to_msgs_starts[factor] + np.sum( + factor = self.fg_state.variables_to_factors[frozenset(keys[0])] + variable = self.fg_state.variable_group[keys[1]] + start = self.fg_state.factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) msgs = jax.device_put(self.value)[start : start + variable.num_states] @@ -563,11 +609,11 @@ def __setitem__(self, keys, data) -> None: if ( isinstance(keys, tuple) and len(keys) == 2 - and keys[1] in self._variable_group.keys + and keys[1] in self.fg_state.variable_group.keys ): - factor = self._variables_to_factors[frozenset(keys[0])] - variable = self._variable_group[keys[1]] - start = self._factor_to_msgs_starts[factor] + np.sum( + factor = self.fg_state.variables_to_factors[frozenset(keys[0])] + variable = self.fg_state.variable_group[keys[1]] + start = self.fg_state.factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) if data.shape != (variable.num_states,): @@ -577,13 +623,15 @@ def __setitem__(self, keys, data) -> None: f"to variable {keys[1]}." ) - self.value = ( + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start : start + variable.num_states] - .set(data) + .set(data), ) - elif keys in self._variable_group.keys: - variable = self._variable_group[keys] + elif keys in self.fg_state.variable_group.keys: + variable = self.fg_state.variable_group[keys] if data.shape != (variable.num_states,): raise ValueError( f"Given belief shape {data.shape} does not match expected " @@ -591,13 +639,16 @@ def __setitem__(self, keys, data) -> None: ) starts = np.nonzero( - self._var_states_for_edges == self._vars_to_starts[variable] + self.fg_state.wiring.var_states_for_edges + == self.fg_state.vars_to_starts[variable] )[0] for start in starts: - self.value = ( + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start : start + variable.num_states] - .st(data / starts.shape[0]) + .st(data / starts.shape[0]), ) else: raise ValueError( @@ -609,6 +660,7 @@ def __setitem__(self, keys, data) -> None: ) +@dataclass(frozen=True, eq=False) class Evidence: """Class for storing and manipulating evidence @@ -624,36 +676,38 @@ class Evidence: representing the evidence for that variable """ - def __init__( - self, - factor_graph: FactorGraph, - default_mode: Optional[str] = None, - value: Optional[Union[np.ndarray, jnp.ndarray]] = None, - ): - if default_mode is not None and value is not None: + fg_state: FactorGraphState + default_mode: Optional[str] = None + value: Optional[Union[np.ndarray, jnp.ndarray]] = None + + def __post_init__(self): + if self.default_mode is not None and self.value is not None: raise ValueError("Should specify only one of default_mode and value.") - if default_mode is None and value is None: - default_mode = "zeros" + if self.default_mode is None and self.value is None: + object.__setattr__(self, "default_mode", "zeros") - if value is None and default_mode not in ("zeros", "random"): + if self.value is None and self.default_mode not in ("zeros", "random"): raise ValueError( - f"Unsupported default evidence mode {default_mode}. " + f"Unsupported default evidence mode {self.default_mode}. " "Supported default modes are zeros or random" ) - if value is None: - if default_mode == "zeros": - self.value = jnp.zeros(factor_graph.num_var_states) + if self.value is None: + if self.default_mode == "zeros": + object.__setattr__( + self, "value", jnp.zeros(self.fg_state.num_var_states) + ) else: - self.value = jax.device_put( - np.random.gumbel(size=(factor_graph.num_var_states,)) + object.__setattr__( + self, + "value", + jax.device_put( + np.random.gumbel(size=(self.fg_state.num_var_states,)) + ), ) else: - self.value = jax.device_put(value) - - self._variable_group = factor_graph._variable_group - self._vars_to_starts = factor_graph._vars_to_starts + object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, key: Any) -> jnp.ndarray: """Function to query evidence for a variable @@ -664,8 +718,8 @@ def __getitem__(self, key: Any) -> jnp.ndarray: Returns: evidence for the queried variable """ - variable = self._variable_group[key] - start = self._vars_to_starts[variable] + variable = self.fg_state.variable_group[key] + start = self.fg_state.vars_to_starts[variable] evidence = jax.device_put(self.value)[start : start + variable.num_states] return evidence @@ -678,7 +732,7 @@ def __setitem__( Args: key: tuple that represents the index into the VariableGroup - (self._variable_group) that is created when the FactorGraph is instantiated. Note that + (self.fg_state.variable_group) that is created when the FactorGraph is instantiated. Note that this can be an index referring to an entire VariableGroup (in which case, the evidence is set for the entire VariableGroup at once), or to an individual Variable within the VariableGroup. @@ -696,43 +750,33 @@ def __setitem__( Note that each np.ndarray in the dictionary values must have the same size as variable_group.variable_size. """ - if key in self._variable_group.container_keys: + if key in self.fg_state.variable_group.container_keys: if key == slice(None): - variable_group = self._variable_group + variable_group = self.fg_state.variable_group else: - variable_group = self._variable_group.variable_group_container[key] + assert isinstance( + self.fg_state.variable_group, groups.CompositeVariableGroup + ) + variable_group = self.fg_state.variable_group[key] for var, evidence_val in variable_group.get_vars_to_evidence( evidence ).items(): - start_index = self._vars_to_starts[var] - self.value = ( + start_index = self.fg_state.vars_to_starts[var] + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start_index : start_index + evidence_val.shape[0]] - .set(evidence_val) + .set(evidence_val), ) else: - var = self._variable_group[key] - start_index = self._vars_to_starts[var] - self.value = ( + var = self.fg_state.variable_group[key] + start_index = self.fg_state.vars_to_starts[var] + object.__setattr__( + self, + "value", jax.device_put(self.value) .at[start_index : start_index + var.num_states] - .set(evidence) + .set(evidence), ) - - -@dataclass -class BPState: - """Container class for belief propagation states, including log potentials, - ftov messages and evidence (unary log potentials). - - Args: - log_potentials: log potentials of the model - ftov: factor to variable messages - evidence: evidence - """ - - wiring: nodes.EnumerationWiring - log_potentials: LogPotentials - ftov: FToVMessages - evidence: Evidence From 77e49bab96c6afbb35d4dea52918ae0566105965 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 22:47:24 -0700 Subject: [PATCH 14/56] Get rid of default modes --- pgmax/fg/graph.py | 73 ++++++++--------------------------------------- 1 file changed, 12 insertions(+), 61 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index bb89a40a..6fed1de0 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -40,11 +40,6 @@ class FactorGraph: For a sequence, the indices of the sequence are used to index the variable groups. Note that if not a single VariableGroup, a CompositeVariableGroup will be created from this input, and the individual VariableGroups will need to be accessed by indexing. - messages_default_mode: default mode for initializing messages. - Allowed values are "zeros" and "random". - evidence_default_mode: default mode for initializing evidence. - Allowed values are "zeros" and "random". - Any variable whose evidence was not explicitly specified using 'set_evidence' Attributes: _variable_group: VariableGroup. contains all involved VariableGroups @@ -264,6 +259,14 @@ def fg_state(self) -> FactorGraphState: wiring=self.wiring, ) + @property + def bp_state(self) -> BPState: + return BPState( + log_potentials=LogPotentials(fg_state=self.fg_state), + ftov=FToVMessages(fg_state=self.fg_state), + evidence=Evidence(fg_state=self.fg_state), + ) + def run_bp( self, num_iters: int, @@ -414,7 +417,6 @@ def fg_state(self) -> FactorGraphState: class LogPotentials: fg_state: FactorGraphState - default_mode: Optional[str] = None value: Optional[np.ndarray] = None def __post_init__(self): @@ -496,9 +498,6 @@ class FToVMessages: Args: factor_graph: associated factor graph - default_mode: default mode for initializing ftov messages. - Allowed values include "zeros" and "random" - If value is None, defaults to "zeros" value: Optionally specify initial value for ftov messages Attributes: @@ -508,34 +507,13 @@ class FToVMessages: """ fg_state: FactorGraphState - default_mode: Optional[str] = None value: Optional[Union[np.ndarray, jnp.ndarray]] = None def __post_init__(self): - if self.default_mode is not None and self.value is not None: - raise ValueError("Should specify only one of default_mode and value.") - - if self.default_mode is None and self.value is None: - object.__setattr__(self, "default_mode", "zeros") - if self.value is None: - if self.default_mode == "zeros": - object.__setattr__( - self, "value", jnp.zeros(self.fg_state.total_factor_num_states) - ) - elif self.default_mode == "random": - object.__setattr__( - self, - "value", - jax.device_put( - np.random.gumbel(size=(self.fg_state.total_factor_num_states,)) - ), - ) - else: - raise ValueError( - f"Unsupported default message mode {self.default_mode}. " - "Supported default modes are zeros or random" - ) + object.__setattr__( + self, "value", jnp.zeros(self.fg_state.total_factor_num_states) + ) else: if not self.value.shape == (self.fg_state.total_factor_num_states,): raise ValueError( @@ -666,9 +644,6 @@ class Evidence: Args: factor_graph: associated factor graph - default_mode: default mode for initializing evidence. - Allowed values include "zeros" and "random" - If value is None, defaults to "zeros" value: Optionally specify initial value for evidence Attributes: @@ -677,35 +652,11 @@ class Evidence: """ fg_state: FactorGraphState - default_mode: Optional[str] = None value: Optional[Union[np.ndarray, jnp.ndarray]] = None def __post_init__(self): - if self.default_mode is not None and self.value is not None: - raise ValueError("Should specify only one of default_mode and value.") - - if self.default_mode is None and self.value is None: - object.__setattr__(self, "default_mode", "zeros") - - if self.value is None and self.default_mode not in ("zeros", "random"): - raise ValueError( - f"Unsupported default evidence mode {self.default_mode}. " - "Supported default modes are zeros or random" - ) - if self.value is None: - if self.default_mode == "zeros": - object.__setattr__( - self, "value", jnp.zeros(self.fg_state.num_var_states) - ) - else: - object.__setattr__( - self, - "value", - jax.device_put( - np.random.gumbel(size=(self.fg_state.num_var_states,)) - ), - ) + object.__setattr__(self, "value", jnp.zeros(self.fg_state.num_var_states)) else: object.__setattr__(self, "value", jax.device_put(self.value)) From 2dc26353daf977fe2b79bdccb422cc03bb84246a Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 23:12:25 -0700 Subject: [PATCH 15/56] Functional updates functions --- pgmax/fg/graph.py | 132 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 124 insertions(+), 8 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6fed1de0..d7984164 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -413,6 +413,44 @@ def fg_state(self) -> FactorGraphState: return self.log_potentials.fg_state +@jax.partial(jax.jit, static_argnames="fg_state") +def update_log_potentials( + log_potentials: jnp.ndarray, + updates: Dict[Any, jnp.ndarray], + fg_state: FactorGraphState, +) -> jnp.ndarray: + for key in updates: + data = updates[key] + if key in fg_state.named_factor_groups: + factor_group = fg_state.named_factor_groups[key] + if data.shape != factor_group.factor_group_log_potentials.shape: + raise ValueError( + f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " + f"for factor group {key}. Got {data.shape}." + ) + + start = fg_state.factor_group_to_potentials_starts[factor_group] + log_potentials = log_potentials.at[ + start : start + factor_group.factor_group_log_potentials.shape[0] + ].set(data) + elif frozenset(key) in fg_state.variables_to_factors: + factor = fg_state.variables_to_factors[frozenset(key)] + if data.shape != factor.log_potentials.shape: + raise ValueError( + f"Expected log potentials shape {factor.log_potentials.shape} " + f"for factor {key}. Got {data.shape}." + ) + + start = fg_state.factor_to_potentials_starts[factor] + log_potentials = log_potentials.at[ + start : start + factor.log_potentials.shape[0] + ].set(data) + else: + raise ValueError(f"Invalid key {key} for log potentials updates.") + + return log_potentials + + @dataclass(frozen=True, eq=False) class LogPotentials: @@ -492,6 +530,58 @@ def __setitem__( raise ValueError("") +@jax.partial(jax.jit, static_argnames="fg_state") +def update_ftov_msgs( + ftov_msgs: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState +) -> jnp.ndarray: + for keys in updates: + data = updates[keys] + if ( + isinstance(keys, tuple) + and len(keys) == 2 + and keys[1] in fg_state.variable_group.keys + ): + factor = fg_state.variables_to_factors[frozenset(keys[0])] + variable = fg_state.variable_group[keys[1]] + start = fg_state.factor_to_msgs_starts[factor] + np.sum( + factor.edges_num_states[: factor.variables.index(variable)] + ) + if data.shape != (variable.num_states,): + raise ValueError( + f"Given message shape {data.shape} does not match expected " + f"shape f{(variable.num_states,)} from factor {keys[0]} " + f"to variable {keys[1]}." + ) + + ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set(data) + elif keys in fg_state.variable_group.keys: + variable = fg_state.variable_group[keys] + if data.shape != (variable.num_states,): + raise ValueError( + f"Given belief shape {data.shape} does not match expected " + f"shape f{(variable.num_states,)} for variable {keys}." + ) + + starts = np.nonzero( + fg_state.wiring.var_states_for_edges + == fg_state.vars_to_starts[variable] + )[0] + for start in starts: + ftov_msgs = ftov_msgs.at[start : start + variable.num_states].st( + data / starts.shape[0] + ) + else: + raise ValueError( + "Invalid keys for setting messages. " + "Supported keys include a tuple of length 2 with factor " + "and variable keys for directly setting factor to variable " + "messages, or a valid variable key for spreading expected " + "beliefs at a variable" + ) + + return ftov_msgs + + @dataclass(frozen=True, eq=False) class FToVMessages: """Class for storing and manipulating factor to variable messages. @@ -638,6 +728,34 @@ def __setitem__(self, keys, data) -> None: ) +@jax.partial(jax.jit, static_argnames="fg_state") +def update_evidence( + evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState +) -> jnp.ndarray: + for key in updates: + data = updates[key] + if key in fg_state.variable_group.container_keys: + if key == slice(None): + variable_group = fg_state.variable_group + else: + assert isinstance( + fg_state.variable_group, groups.CompositeVariableGroup + ) + variable_group = fg_state.variable_group[key] + + for var, evidence_val in variable_group.get_vars_to_evidence(data).items(): + start_index = fg_state.vars_to_starts[var] + evidence = evidence.at[ + start_index : start_index + evidence_val.shape[0] + ].set(evidence_val) + else: + var = fg_state.variable_group[key] + start_index = fg_state.vars_to_starts[var] + evidence = evidence.at[start_index : start_index + var.num_states].set(data) + + return evidence + + @dataclass(frozen=True, eq=False) class Evidence: """Class for storing and manipulating evidence @@ -677,7 +795,7 @@ def __getitem__(self, key: Any) -> jnp.ndarray: def __setitem__( self, key: Any, - evidence: Union[Dict[Hashable, np.ndarray], np.ndarray], + data: Union[Dict[Hashable, np.ndarray], np.ndarray], ) -> None: """Function to update the evidence for variables @@ -687,16 +805,16 @@ def __setitem__( this can be an index referring to an entire VariableGroup (in which case, the evidence is set for the entire VariableGroup at once), or to an individual Variable within the VariableGroup. - evidence: a container for np.ndarrays representing the evidence + data: a container for np.ndarrays representing the evidence Currently supported containers are: - - an np.ndarray: if key indexes an NDVariableArray, then evidence_values + - an np.ndarray: if key indexes an NDVariableArray, then data can simply be an np.ndarray with num_var_array_dims + 1 dimensions where num_var_array_dims is the number of dimensions of the NDVariableArray, and the +1 represents a dimension (that should be the final dimension) for the evidence. Note that the size of the final dimension should be the same as variable_group.variable_size. if key indexes a particular variable, then this array must be of the same size as variable.num_states - - a dictionary: if key indexes a VariableDict, then evidence_values + - a dictionary: if key indexes a VariableDict, then data must be a dictionary mapping keys of variable_group to np.ndarrays of evidence values. Note that each np.ndarray in the dictionary values must have the same size as variable_group.variable_size. @@ -710,9 +828,7 @@ def __setitem__( ) variable_group = self.fg_state.variable_group[key] - for var, evidence_val in variable_group.get_vars_to_evidence( - evidence - ).items(): + for var, evidence_val in variable_group.get_vars_to_evidence(data).items(): start_index = self.fg_state.vars_to_starts[var] object.__setattr__( self, @@ -729,5 +845,5 @@ def __setitem__( "value", jax.device_put(self.value) .at[start_index : start_index + var.num_states] - .set(evidence), + .set(data), ) From d247bd867c342cebbe8b0f8fe8ac51856419bfb7 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 23:19:33 -0700 Subject: [PATCH 16/56] Functional BP interface --- pgmax/fg/graph.py | 28 +++++++------ pgmax/fg/transforms.py | 93 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 12 deletions(-) create mode 100644 pgmax/fg/transforms.py diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index d7984164..1ccf0821 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -263,7 +263,7 @@ def fg_state(self) -> FactorGraphState: def bp_state(self) -> BPState: return BPState( log_potentials=LogPotentials(fg_state=self.fg_state), - ftov=FToVMessages(fg_state=self.fg_state), + ftov_msgs=FToVMessages(fg_state=self.fg_state), evidence=Evidence(fg_state=self.fg_state), ) @@ -288,7 +288,7 @@ def run_bp( """ # Retrieve the necessary data structures from the compiled self.wiring and # convert these to jax arrays. - msgs = jax.device_put(bp_state.ftov.value) + msgs = jax.device_put(bp_state.ftov_msgs.value) evidence = jax.device_put(bp_state.evidence.value) wiring = jax.device_put(bp_state.fg_state.wiring) log_potentials = jax.device_put(bp_state.log_potentials.value) @@ -331,7 +331,9 @@ def message_passing_step(msgs, _): msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) return replace( bp_state, - ftov=FToVMessages(fg_state=bp_state.ftov.fg_state, value=msgs_after_bp), + ftov_msgs=FToVMessages( + fg_state=bp_state.ftov_msgs.fg_state, value=msgs_after_bp + ), ) def decode_map_states(self, bp_state: BPState) -> Dict[Tuple[Any, ...], int]: @@ -350,11 +352,13 @@ def decode_map_states(self, bp_state: BPState) -> Dict[Tuple[Any, ...], int]: bp_state.fg_state.wiring.var_states_for_edges ) evidence = jax.device_put(bp_state.evidence.value) - final_var_states = evidence.at[var_states_for_edges].add(bp_state.ftov.value) + final_var_states = evidence.at[var_states_for_edges].add( + bp_state.ftov_msgs.value + ) var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {} - for var_key in bp_state.ftov.fg_state.variable_group.keys: - var = bp_state.ftov.fg_state.variable_group[var_key] - start_index = bp_state.ftov.fg_state.vars_to_starts[var] + for var_key in bp_state.ftov_msgs.fg_state.variable_group.keys: + var = bp_state.ftov_msgs.fg_state.variable_group[var_key] + start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[var] var_key_to_map_dict[var_key] = int( jnp.argmax(final_var_states[start_index : start_index + var.num_states]) ) @@ -392,20 +396,20 @@ class BPState: Args: log_potentials: log potentials of the model - ftov: factor to variable messages + ftov_msgs: factor to variable messages evidence: evidence """ log_potentials: LogPotentials - ftov: FToVMessages + ftov_msgs: FToVMessages evidence: Evidence def __post_init__(self): - if (self.log_potentials.fg_state != self.ftov.fg_state) or ( - self.ftov.fg_state != self.evidence.fg_state + if (self.log_potentials.fg_state != self.ftov_msgs.fg_state) or ( + self.ftov_msgs.fg_state != self.evidence.fg_state ): raise ValueError( - "log_potentials, ftov and evidence should be derived from the same fg_state." + "log_potentials, ftov_msgs and evidence should be derived from the same fg_state." ) @property diff --git a/pgmax/fg/transforms.py b/pgmax/fg/transforms.py new file mode 100644 index 00000000..99d251ca --- /dev/null +++ b/pgmax/fg/transforms.py @@ -0,0 +1,93 @@ +from dataclasses import replace +from typing import Any, Dict, Optional + +import jax +import jax.numpy as jnp + +from pgmax.bp import infer +from pgmax.fg import graph + + +def BP(bp_state: graph.BPState, num_iters: int): + @jax.jit + def run_bp( + log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None, + ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None, + evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None, + damping: float = 0.5, + ): + """Function to perform belief propagation. + + Specifically, belief propagation is run for num_iters iterations and + returns the resulting messages. + + Args: + num_iters: The number of iterations for which to perform message passing + damping: The damping factor to use for message updates between one timestep and the next + bp_state: Initial messages to start the belief propagation. + + Returns: + ftov messages after running BP for num_iters iterations + """ + # Retrieve the necessary data structures from the compiled self.wiring and + # convert these to jax arrays. + log_potentials = jax.device_put(bp_state.log_potentials.value) + if log_potentials_updates is not None: + log_potentials = graph.update_log_potentials( + log_potentials, log_potentials_updates, bp_state.fg_state + ) + + ftov_msgs = jax.device_put(bp_state.ftov_msgs.value) + if ftov_msgs_updates is not None: + ftov_msgs = graph.update_ftov_msgs( + ftov_msgs, ftov_msgs_updates, bp_state.fg_state + ) + + evidence = jax.device_put(bp_state.evidence.value) + if evidence_updates is not None: + evidence = graph.update_evidence( + evidence, evidence_updates, bp_state.fg_state + ) + + wiring = jax.device_put(bp_state.fg_state.wiring) + max_msg_size = int(jnp.max(wiring.edges_num_states)) + # Normalize the messages to ensure the maximum value is 0. + ftov_msgs = infer.normalize_and_clip_msgs( + ftov_msgs, wiring.edges_num_states, max_msg_size + ) + num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 + + def update(msgs, _): + # Compute new variable to factor messages by message passing + vtof_msgs = infer.pass_var_to_fac_messages( + msgs, + evidence, + wiring.var_states_for_edges, + ) + # Compute new factor to variable messages by message passing + ftov_msgs = infer.pass_fac_to_var_messages( + vtof_msgs, + wiring.factor_configs_edge_states, + log_potentials, + num_val_configs, + ) + # Use the results of message passing to perform damping and + # update the factor to variable messages + delta_msgs = ftov_msgs - msgs + msgs = msgs + (1 - damping) * delta_msgs + # Normalize and clip these damped, updated messages before returning + # them. + msgs = infer.normalize_and_clip_msgs( + msgs, + wiring.edges_num_states, + max_msg_size, + ) + return msgs, None + + ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters) + return replace( + bp_state, + ftov_msgs=graph.FToVMessages( + fg_state=bp_state.ftov_msgs.fg_state, value=ftov_msgs + ), + ) From f02308246e2bf563df5676ebe55eee9ceda2ef5c Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 23:20:23 -0700 Subject: [PATCH 17/56] Remove old implementation --- pgmax/fg/graph.py | 72 +---------------------------------------------- 1 file changed, 1 insertion(+), 71 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 1ccf0821..718078ab 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -5,7 +5,7 @@ import collections import copy import typing -from dataclasses import dataclass, replace +from dataclasses import dataclass from types import MappingProxyType from typing import ( Any, @@ -24,7 +24,6 @@ import jax.numpy as jnp import numpy as np -from pgmax.bp import infer from pgmax.fg import fg_utils, groups, nodes from pgmax.utils import cached_property @@ -267,75 +266,6 @@ def bp_state(self) -> BPState: evidence=Evidence(fg_state=self.fg_state), ) - def run_bp( - self, - num_iters: int, - damping_factor: float, - bp_state: BPState, - ) -> BPState: - """Function to perform belief propagation. - - Specifically, belief propagation is run for num_iters iterations and - returns the resulting messages. - - Args: - num_iters: The number of iterations for which to perform message passing - damping_factor: The damping factor to use for message updates between one timestep and the next - bp_state: Initial messages to start the belief propagation. - - Returns: - ftov messages after running BP for num_iters iterations - """ - # Retrieve the necessary data structures from the compiled self.wiring and - # convert these to jax arrays. - msgs = jax.device_put(bp_state.ftov_msgs.value) - evidence = jax.device_put(bp_state.evidence.value) - wiring = jax.device_put(bp_state.fg_state.wiring) - log_potentials = jax.device_put(bp_state.log_potentials.value) - max_msg_size = int(jnp.max(wiring.edges_num_states)) - - # Normalize the messages to ensure the maximum value is 0. - msgs = infer.normalize_and_clip_msgs( - msgs, wiring.edges_num_states, max_msg_size - ) - num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 - - @jax.jit - def message_passing_step(msgs, _): - # Compute new variable to factor messages by message passing - vtof_msgs = infer.pass_var_to_fac_messages( - msgs, - evidence, - wiring.var_states_for_edges, - ) - # Compute new factor to variable messages by message passing - ftov_msgs = infer.pass_fac_to_var_messages( - vtof_msgs, - wiring.factor_configs_edge_states, - log_potentials, - num_val_configs, - ) - # Use the results of message passing to perform damping and - # update the factor to variable messages - delta_msgs = ftov_msgs - msgs - msgs = msgs + (1 - damping_factor) * delta_msgs - # Normalize and clip these damped, updated messages before returning - # them. - msgs = infer.normalize_and_clip_msgs( - msgs, - wiring.edges_num_states, - max_msg_size, - ) - return msgs, None - - msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) - return replace( - bp_state, - ftov_msgs=FToVMessages( - fg_state=bp_state.ftov_msgs.fg_state, value=msgs_after_bp - ), - ) - def decode_map_states(self, bp_state: BPState) -> Dict[Tuple[Any, ...], int]: """Function to computes the output of MAP inference on input messages. From b00345f0ffbadc8c8947f0964da5975f28b80fd6 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 23:26:05 -0700 Subject: [PATCH 18/56] Use functions for setitem --- pgmax/fg/graph.py | 135 ++++++++-------------------------------------- 1 file changed, 21 insertions(+), 114 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 718078ab..9dff0d7c 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -428,40 +428,13 @@ def __setitem__( key: Any, data: Union[np.ndarray, jnp.ndarray], ): - if key in self.fg_state.named_factor_groups: - factor_group = self.fg_state.named_factor_groups[key] - if data.shape != factor_group.factor_group_log_potentials.shape: - raise ValueError( - f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " - f"for factor group {key}. Got {data.shape}." - ) - - start = self.fg_state.factor_group_to_potentials_starts[factor_group] - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start : start + factor_group.factor_group_log_potentials.shape[0]] - .set(data), - ) - elif frozenset(key) in self.fg_state.variables_to_factors: - factor = self.fg_state.variables_to_factors[frozenset(key)] - if data.shape != factor.log_potentials.shape: - raise ValueError( - f"Expected log potentials shape {factor.log_potentials.shape} " - f"for factor {key}. Got {data.shape}." - ) - - start = self.fg_state.factor_to_potentials_starts[factor] - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start : start + factor.log_potentials.shape[0]] - .set(data), - ) - else: - raise ValueError("") + object.__setattr__( + self, + "value", + update_log_potentials( + jax.device_put(self.value), {key: jax.device_put(data)}, self.fg_state + ), + ) @jax.partial(jax.jit, static_argnames="fg_state") @@ -608,58 +581,13 @@ def __setitem__( """ def __setitem__(self, keys, data) -> None: - if ( - isinstance(keys, tuple) - and len(keys) == 2 - and keys[1] in self.fg_state.variable_group.keys - ): - factor = self.fg_state.variables_to_factors[frozenset(keys[0])] - variable = self.fg_state.variable_group[keys[1]] - start = self.fg_state.factor_to_msgs_starts[factor] + np.sum( - factor.edges_num_states[: factor.variables.index(variable)] - ) - if data.shape != (variable.num_states,): - raise ValueError( - f"Given message shape {data.shape} does not match expected " - f"shape f{(variable.num_states,)} from factor {keys[0]} " - f"to variable {keys[1]}." - ) - - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start : start + variable.num_states] - .set(data), - ) - elif keys in self.fg_state.variable_group.keys: - variable = self.fg_state.variable_group[keys] - if data.shape != (variable.num_states,): - raise ValueError( - f"Given belief shape {data.shape} does not match expected " - f"shape f{(variable.num_states,)} for variable {keys}." - ) - - starts = np.nonzero( - self.fg_state.wiring.var_states_for_edges - == self.fg_state.vars_to_starts[variable] - )[0] - for start in starts: - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start : start + variable.num_states] - .st(data / starts.shape[0]), - ) - else: - raise ValueError( - "Invalid keys for setting messages. " - "Supported keys include a tuple of length 2 with factor " - "and variable keys for directly setting factor to variable " - "messages, or a valid variable key for spreading expected " - "beliefs at a variable" - ) + object.__setattr__( + self, + "value", + update_ftov_msgs( + jax.device_put(self.value), {keys: jax.device_put(data)}, self.fg_state + ), + ) @jax.partial(jax.jit, static_argnames="fg_state") @@ -753,31 +681,10 @@ def __setitem__( Note that each np.ndarray in the dictionary values must have the same size as variable_group.variable_size. """ - if key in self.fg_state.variable_group.container_keys: - if key == slice(None): - variable_group = self.fg_state.variable_group - else: - assert isinstance( - self.fg_state.variable_group, groups.CompositeVariableGroup - ) - variable_group = self.fg_state.variable_group[key] - - for var, evidence_val in variable_group.get_vars_to_evidence(data).items(): - start_index = self.fg_state.vars_to_starts[var] - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start_index : start_index + evidence_val.shape[0]] - .set(evidence_val), - ) - else: - var = self.fg_state.variable_group[key] - start_index = self.fg_state.vars_to_starts[var] - object.__setattr__( - self, - "value", - jax.device_put(self.value) - .at[start_index : start_index + var.num_states] - .set(data), - ) + object.__setattr__( + self, + "value", + update_evidence( + jax.device_put(self.value), {key: jax.device_put(data)}, self.fg_state + ), + ) From ab00866fc282db979cbb50f1bf17b044336640b7 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 24 Oct 2021 23:34:20 -0700 Subject: [PATCH 19/56] Implement decode map states --- pgmax/fg/graph.py | 29 ----------------------------- pgmax/fg/transforms.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 9dff0d7c..1bc3c43a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -266,35 +266,6 @@ def bp_state(self) -> BPState: evidence=Evidence(fg_state=self.fg_state), ) - def decode_map_states(self, bp_state: BPState) -> Dict[Tuple[Any, ...], int]: - """Function to computes the output of MAP inference on input messages. - - The final states are computed based on evidence obtained from the self.get_evidence - method as well as the internal wiring. - - Args: - bp_state: ftov messages for deciding MAP states - - Returns: - a dictionary mapping each variable key to the MAP states of the corresponding variable - """ - var_states_for_edges = jax.device_put( - bp_state.fg_state.wiring.var_states_for_edges - ) - evidence = jax.device_put(bp_state.evidence.value) - final_var_states = evidence.at[var_states_for_edges].add( - bp_state.ftov_msgs.value - ) - var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {} - for var_key in bp_state.ftov_msgs.fg_state.variable_group.keys: - var = bp_state.ftov_msgs.fg_state.variable_group[var_key] - start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[var] - var_key_to_map_dict[var_key] = int( - jnp.argmax(final_var_states[start_index : start_index + var.num_states]) - ) - - return var_key_to_map_dict - @dataclass(frozen=True, eq=False) class FactorGraphState: diff --git a/pgmax/fg/transforms.py b/pgmax/fg/transforms.py index 99d251ca..d077a765 100644 --- a/pgmax/fg/transforms.py +++ b/pgmax/fg/transforms.py @@ -1,5 +1,5 @@ from dataclasses import replace -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -91,3 +91,31 @@ def update(msgs, _): fg_state=bp_state.ftov_msgs.fg_state, value=ftov_msgs ), ) + + +def DecodeMAPState(bp_state: graph.BPState): + @jax.jit + def decode_map_state( + variable_name: Any = None, + ) -> Union[int, Dict[Tuple[Any, ...], int]]: + var_states_for_edges = jax.device_put( + bp_state.fg_state.wiring.var_states_for_edges + ) + evidence = jax.device_put(bp_state.evidence.value) + beliefs = evidence.at[var_states_for_edges].add(bp_state.ftov_msgs.value) + if variable_name is None: + variables_to_map_states: Dict[Tuple[Any, ...], int] = {} + for variable_name in bp_state.ftov_msgs.fg_state.variable_group.keys: + variable = bp_state.ftov_msgs.fg_state.variable_group[variable_name] + start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[variable] + variables_to_map_states[variable_name] = int( + jnp.argmax(beliefs[start_index : start_index + variable.num_states]) + ) + + return variables_to_map_states + else: + variable = bp_state.ftov_msgs.fg_state.variable_group[variable_name] + start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[variable] + return int( + jnp.argmax(beliefs[start_index : start_index + variable.num_states]) + ) From 28f7696a9e15937306ab115ef597fed03b2c7388 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 00:04:42 -0700 Subject: [PATCH 20/56] Make Ising model example run again --- pgmax/fg/graph.py | 6 +++--- pgmax/fg/groups.py | 4 ++-- pgmax/fg/transforms.py | 19 ++++++++++++++----- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 1bc3c43a..83161cd3 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -102,7 +102,7 @@ def __post_init__(self): self._factor_group_to_potentials_starts: OrderedDict[ groups.FactorGroup, int ] = collections.OrderedDict() - self._factor_to_potentials_starts = OrderedDict[ + self._factor_to_potentials_starts: OrderedDict[ nodes.EnumerationFactor, int ] = collections.OrderedDict() @@ -287,7 +287,7 @@ def __post_init__(self): getattr(self, field).flags.writeable = False if isinstance(getattr(self, field), Mapping): - object.__setattr__(self, field, MappingProxyType(self.field)) + object.__setattr__(self, field, MappingProxyType(getattr(self, field))) @dataclass(frozen=True, eq=False) @@ -568,7 +568,7 @@ def update_evidence( for key in updates: data = updates[key] if key in fg_state.variable_group.container_keys: - if key == slice(None): + if key is None: variable_group = fg_state.variable_group else: assert isinstance( diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index d0cfa510..8fb230ae 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -125,10 +125,10 @@ def variables(self) -> Tuple[nodes.Variable, ...]: @cached_property def container_keys(self) -> Tuple: - """Placeholder function. Returns a tuple containing slice(None) for all variable groups + """Placeholder function. Returns a tuple containing None for all variable groups other than a composite variable group """ - return (slice(None),) + return (None,) @dataclass(frozen=True, eq=False) diff --git a/pgmax/fg/transforms.py b/pgmax/fg/transforms.py index d077a765..d91d0523 100644 --- a/pgmax/fg/transforms.py +++ b/pgmax/fg/transforms.py @@ -9,6 +9,11 @@ def BP(bp_state: graph.BPState, num_iters: int): + max_msg_size = int(jnp.max(bp_state.fg_state.wiring.edges_num_states)) + num_val_configs = ( + int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1 + ) + @jax.jit def run_bp( log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None, @@ -50,12 +55,10 @@ def run_bp( ) wiring = jax.device_put(bp_state.fg_state.wiring) - max_msg_size = int(jnp.max(wiring.edges_num_states)) # Normalize the messages to ensure the maximum value is 0. ftov_msgs = infer.normalize_and_clip_msgs( ftov_msgs, wiring.edges_num_states, max_msg_size ) - num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 def update(msgs, _): # Compute new variable to factor messages by message passing @@ -85,6 +88,9 @@ def update(msgs, _): return msgs, None ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters) + return ftov_msgs + + def get_bp_state(ftov_msgs): return replace( bp_state, ftov_msgs=graph.FToVMessages( @@ -92,10 +98,11 @@ def update(msgs, _): ), ) + return run_bp, get_bp_state -def DecodeMAPState(bp_state: graph.BPState): - @jax.jit - def decode_map_state( + +def DecodeMAPStates(bp_state: graph.BPState): + def decode_map_states( variable_name: Any = None, ) -> Union[int, Dict[Tuple[Any, ...], int]]: var_states_for_edges = jax.device_put( @@ -119,3 +126,5 @@ def decode_map_state( return int( jnp.argmax(beliefs[start_index : start_index + variable.num_states]) ) + + return decode_map_states From d6acb08c1eae9be28fe198737fb34c2cb5301562 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 11:13:08 -0700 Subject: [PATCH 21/56] Updated ising model notebook --- examples/ising_model.py | 38 ++++++++++++++++++++++++-------------- pgmax/fg/graph.py | 9 ++++++++- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index 9b58b37f..04f4d3db 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -15,17 +15,18 @@ # %% # %matplotlib inline +import jax import matplotlib.pyplot as plt import numpy as np -from pgmax.fg import graph, groups +from pgmax.fg import graph, groups, transforms # %% [markdown] # ### Construct variable grid, initialize factor graph, and add factors # %% variables = groups.NDVariableArray(variable_size=2, shape=(50, 50)) -fg = graph.FactorGraph(variables=variables, evidence_default_mode="random") +fg = graph.FactorGraph(variables=variables) connected_var_keys = [] for ii in range(50): for jj in range(50): @@ -45,8 +46,17 @@ # ### Run inference and visualize results # %% -msgs = fg.run_bp(3000, 0.5) -map_states = fg.decode_map_states(msgs) +run_bp, get_bp_state = transforms.BP(fg.bp_state, 3000) + +# %% +ftov_msgs = run_bp( + evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} +) +bp_state = get_bp_state(ftov_msgs) + +# %% +decode_map_states = transforms.DecodeMAPStates(bp_state) +map_states = decode_map_states() img = np.zeros((50, 50)) for key in map_states: img[key] = map_states[key] @@ -59,29 +69,29 @@ # %% # Query evidence for variable (0, 0) -msgs.evidence[0, 0] +bp_state.evidence[0, 0] # %% # Set evidence for variable (0, 0) -msgs.evidence[0, 0] = np.array([1.0, 1.0]) -msgs.evidence[0, 0] +bp_state.evidence[0, 0] = np.array([1.0, 1.0]) +bp_state.evidence[0, 0] # %% # Set evidence for all variables using an array evidence = np.random.randn(50, 50, 2) -msgs.evidence[:] = evidence -msgs.evidence[10, 10] == evidence[10, 10] +bp_state.evidence[None] = evidence +bp_state.evidence[10, 10] == evidence[10, 10] # %% # Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0) -msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] +bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] # %% # Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0) -msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] = np.array([1.0, 1.0]) -msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] +bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] = np.array([1.0, 1.0]) +bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] # %% # Uniformly spread expected belief at a variable to all connected factors -msgs.ftov[0, 0] = np.array([1.0, 1.0]) -msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] +bp_state.ftov_msgs[0, 0] = np.array([1.0, 1.0]) +bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 83161cd3..6124a186 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -445,7 +445,7 @@ def update_ftov_msgs( == fg_state.vars_to_starts[variable] )[0] for start in starts: - ftov_msgs = ftov_msgs.at[start : start + variable.num_states].st( + ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set( data / starts.shape[0] ) else: @@ -552,6 +552,13 @@ def __setitem__( """ def __setitem__(self, keys, data) -> None: + if ( + isinstance(keys, tuple) + and len(keys) == 2 + and keys[1] in self.fg_state.variable_group.keys + ): + keys = (frozenset(keys[0]), keys[1]) + object.__setattr__( self, "value", From b590a5e6045bec987cd672c96cde0155e057854f Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 11:30:55 -0700 Subject: [PATCH 22/56] Make RBM example run again --- examples/rbm.py | 27 ++++++++++++++++----------- pgmax/fg/graph.py | 2 +- pgmax/fg/groups.py | 6 +++--- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/rbm.py b/examples/rbm.py index 0a0ce86d..7e3d29e4 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -20,7 +20,7 @@ import matplotlib.pyplot as plt import numpy as np -from pgmax.fg import graph, groups +from pgmax.fg import graph, groups, transforms # %% # Load parameters @@ -47,19 +47,24 @@ ) # %% -# Set evidence -init_msgs = fg.get_init_msgs() -init_msgs.evidence["hidden"] = np.stack( - [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 -) -init_msgs.evidence["visible"] = np.stack( - [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 -) +run_bp, get_bp_state = transforms.BP(fg.bp_state, 100) # %% # Run inference and decode -msgs = fg.run_bp(100, 0.5, init_msgs) -map_states = fg.decode_map_states(msgs) +bp_state = get_bp_state( + run_bp( + evidence_updates={ + "hidden": np.stack( + [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 + ), + "visible": np.stack( + [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 + ), + } + ) +) +decode_map_states = transforms.DecodeMAPStates(bp_state) +map_states = decode_map_states() # %% # Visualize decodings diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6124a186..c4725abc 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -581,7 +581,7 @@ def update_evidence( assert isinstance( fg_state.variable_group, groups.CompositeVariableGroup ) - variable_group = fg_state.variable_group[key] + variable_group = fg_state.variable_group.variable_group_container[key] for var, evidence_val in variable_group.get_vars_to_evidence(data).items(): start_index = fg_state.vars_to_starts[var] diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 8fb230ae..9b124516 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -519,10 +519,10 @@ def _get_variables_to_factors( log_potentials = np.zeros((num_factors, num_val_configs), dtype=float) else: if self.log_potentials.shape != ( - self.factor_configs.shape[0], - ) or self.log_potentials.shape != ( + num_val_configs, + ) and self.log_potentials.shape != ( num_factors, - self.factor_configs.shape[0], + num_val_configs, ): raise ValueError( f"Expected log potentials shape: {(num_val_configs,)} or {(num_factors, num_val_configs)}. " From 0189152ceb768856c2911624c806b090772fd0bd Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 12:52:02 -0700 Subject: [PATCH 23/56] Implement flatten/unflatten for variable groups --- pgmax/fg/groups.py | 179 ++++++++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 69 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 9b124516..eb14823f 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -18,6 +18,7 @@ Union, ) +import jax.numpy as jnp import numpy as np import pgmax.fg.nodes as nodes @@ -95,16 +96,6 @@ def _get_keys_to_vars(self) -> OrderedDict[Any, nodes.Variable]: "Please subclass the VariableGroup class and override this method" ) - def get_vars_to_evidence(self, evidence: Any) -> Dict[nodes.Variable, np.ndarray]: - """Function that turns input evidence into a dictionary mapping variables to evidence. - - Returns: - a dictionary mapping all possible variables to the corresponding evidence - """ - raise NotImplementedError( - "Please subclass the VariableGroup class and override this method" - ) - @cached_property def keys(self) -> Tuple[Any, ...]: """Function to return a tuple of all keys in the group. @@ -130,6 +121,16 @@ def container_keys(self) -> Tuple: """ return (None,) + def flatten(self, data: Any) -> np.ndarray: + raise NotImplementedError( + "Please subclass the VariableGroup class and override this method" + ) + + def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: + raise NotImplementedError( + "Please subclass the VariableGroup class and override this method" + ) + @dataclass(frozen=True, eq=False) class CompositeVariableGroup(VariableGroup): @@ -221,25 +222,57 @@ def _get_keys_to_vars(self) -> OrderedDict[Hashable, nodes.Variable]: return keys_to_vars - def get_vars_to_evidence( - self, evidence: Union[Mapping, Sequence] - ) -> Dict[nodes.Variable, np.ndarray]: - """Function that turns input evidence into a dictionary mapping variables to evidence. + def flatten(self, data: Union[Mapping, Sequence]) -> np.ndarray: + flat_data = np.concatenate( + [ + self.variable_group_container[key].flatten(data[key]) + for key in self.container_keys + ] + ) + return flat_data - Args: - evidence: A mapping or a sequence of evidences. - The type of evidence should match that of self.variable_group_container. + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Union[Mapping, Sequence]: + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + ) - Returns: - a dictionary mapping all possible variables to the corresponding evidence - """ - vars_to_evidence: Dict[nodes.Variable, np.ndarray] = {} + num_variables = 0 + num_variable_states = 0 for key in self.container_keys: - vars_to_evidence.update( - self.variable_group_container[key].get_vars_to_evidence(evidence[key]) + variable_group = self.variable_group_container[key] + num_variables += len(variable_group.variables) + num_variable_states += ( + len(variable_group.variables) * variable_group.variables[0].num_states ) - return vars_to_evidence + if flat_data.shape[0] == num_variables: + use_num_states = False + elif flat_data.shape[0] == num_variable_states: + use_num_states = True + else: + raise ValueError( + f"flat_data should either be of shape (num_variables={len(self.variables)},), " + f"or (num_variable_states={num_variable_states},). " + f"Got {flat_data.shape}" + ) + + data: List[np.ndarray] = [] + start = 0 + for key in self.container_keys: + variable_group = self.variable_group_container[key] + length = len(variable_group.variables) + if use_num_states: + length *= variable_group.variables[0].num_states + + data.append(variable_group.unflatten(flat_data[start : start + length])) + start += length + if isinstance(self.variable_group_container, Mapping): + return dict([(key, data[kk]) for kk, key in enumerate(self.container_keys)]) + else: + return data @cached_property def container_keys(self) -> Tuple: @@ -290,30 +323,29 @@ def _get_keys_to_vars( return keys_to_vars - def get_vars_to_evidence( - self, evidence: np.ndarray - ) -> Dict[nodes.Variable, np.ndarray]: - """Function that turns input evidence into a dictionary mapping variables to evidence. - - Args: - evidence: An array of shape self.shape + (variable_size,) - An array containing evidence for all the variables + def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + if data.shape != self.shape and data.shape != self.shape + ( + self.variable_size, + ): + raise ValueError( + f"data should be of shape {self.shape} or {self.shape + (self.variable_size,)}. " + f"Got {data.shape}." + ) - Returns: - a dictionary mapping all possible variables to the corresponding evidence + return data.flatten() - Raises: - ValueError: if input evidence array is of the wrong shape - """ - expected_shape = self.shape + (self.variable_size,) - if not evidence.shape == expected_shape: + def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + if flat_data.size == np.product(self.shape): + data = flat_data.reshape(self.shape).copy() + elif flat_data.size == np.product(self.shape) * self.variable_size: + data = flat_data.reshape(self.shape + (self.variable_size,)).copy() + else: raise ValueError( - f"Input evidence should be an array of shape {expected_shape}. " - f"Got {evidence.shape}." + f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.variable_size,)}. " + f"Got {flat_data.shape}." ) - vars_to_evidence = {self._keys_to_vars[self.keys[0]]: evidence.ravel()} - return vars_to_evidence + return data @dataclass(frozen=True, eq=False) @@ -343,39 +375,48 @@ def _get_keys_to_vars(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: return keys_to_vars - def get_vars_to_evidence( - self, evidence: Mapping[Hashable, np.ndarray] - ) -> Dict[nodes.Variable, np.ndarray]: - """Function that turns input evidence into a dictionary mapping variables to evidence. - - Args: - evidence: A mapping from keys to np.ndarrays of evidence for that particular - key - - Returns: - a dictionary mapping all possible variables to the corresponding evidence - - Raises: - ValueError: if a key has not previously been added to this VariableGroup, or - if any evidence array is of the wrong shape. - """ - vars_to_evidence = {} - for key in evidence: + def flatten(self, data: Mapping[Hashable, np.ndarray]) -> np.ndarray: + for key in data: if key not in self._keys_to_vars: - raise ValueError( - f"The evidence is referring to a non-existent variable {key}." - ) + raise ValueError(f"data is referring to a non-existent variable {key}.") - if evidence[key].shape != (self.variable_size,): + if data[key].shape != (self.variable_size,): raise ValueError( - f"Variable {key} expects an evidence array of shape " + f"Variable {key} expects an data array of shape " f"({(self.variable_size,)})." - f"Got {evidence[key].shape}." + f"Got {data[key].shape}." ) - vars_to_evidence[self._keys_to_vars[key]] = evidence[key] + flat_data = np.concatenate([data[key].flatten() for key in self.keys]) + return flat_data + + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Dict[Hashable, np.ndarray]: + num_variables = len(self.variable_names) + num_variable_states = len(self.variable_names) * self.variable_size + if flat_data.shape[0] == num_variables: + use_num_states = False + elif flat_data.shape[0] == num_variable_states: + use_num_states = True + else: + raise ValueError( + f"flat_data should either be of shape (num_variables={len(self.variables)},), " + f"or (num_variable_states={num_variable_states},). " + f"Got {flat_data.shape}" + ) + + start = 0 + data = {} + for key in self.variable_names: + if use_num_states: + data[key] = flat_data[start : start + self.variable_size] + start += self.variable_size + else: + data[key] = flat_data[start] + start += 1 - return vars_to_evidence + return data @dataclass(frozen=True, eq=False) From 05651c9a8f5dd7de3eebb8ef5369e61b3ab4ad7a Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 12:54:50 -0700 Subject: [PATCH 24/56] Use flatten in evidence updates --- pgmax/fg/graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index c4725abc..45ae344a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -583,11 +583,11 @@ def update_evidence( ) variable_group = fg_state.variable_group.variable_group_container[key] - for var, evidence_val in variable_group.get_vars_to_evidence(data).items(): - start_index = fg_state.vars_to_starts[var] - evidence = evidence.at[ - start_index : start_index + evidence_val.shape[0] - ].set(evidence_val) + start_index = fg_state.vars_to_starts[variable_group.variables[0]] + flat_data = variable_group.flatten(data) + evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set( + flat_data + ) else: var = fg_state.variable_group[key] start_index = fg_state.vars_to_starts[var] From 19c59db7f337f26dc252907ac03f0486ebc75351 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 13:19:21 -0700 Subject: [PATCH 25/56] flatten/unflatten for factor groups --- pgmax/fg/groups.py | 77 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index eb14823f..26551ab3 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -335,6 +335,11 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: return data.flatten() def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + ) + if flat_data.size == np.product(self.shape): data = flat_data.reshape(self.shape).copy() elif flat_data.size == np.product(self.shape) * self.variable_size: @@ -393,6 +398,11 @@ def flatten(self, data: Mapping[Hashable, np.ndarray]) -> np.ndarray: def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Dict[Hashable, np.ndarray]: + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + ) + num_variables = len(self.variable_names) num_variable_states = len(self.variable_names) * self.variable_size if flat_data.shape[0] == num_variables: @@ -521,6 +531,16 @@ def factor_num_states(self) -> np.ndarray: ) return factor_num_states + def flatten(self, data: np.ndarray) -> np.ndarray: + raise NotImplementedError( + "Please subclass the FactorGroup class and override this method" + ) + + def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + raise NotImplementedError( + "Please subclass the FactorGroup class and override this method" + ) + @dataclass(frozen=True, eq=False) class EnumerationFactorGroup(FactorGroup): @@ -589,6 +609,12 @@ def _get_variables_to_factors( ) return variables_to_factors + def flatten(self, data: np.ndarray) -> np.ndarray: + pass + + def unflatten(self, flat_data: np.ndarray) -> np.ndarray: + pass + @dataclass(frozen=True, eq=False) class PairwiseFactorGroup(FactorGroup): @@ -663,9 +689,14 @@ def _get_variables_to_factors( "(with {self.log_potential_matrix.shape[-2:]} configurations)." ) - factor_configs = np.mgrid[ - : self.log_potential_matrix.shape[0], : self.log_potential_matrix.shape[1] - ].T.reshape((-1, 2)) + factor_configs = ( + np.mgrid[ + : self.log_potential_matrix.shape[0], + : self.log_potential_matrix.shape[1], + ] + .transpose((1, 2, 0)) + .reshape((-1, 2)) + ) log_potential_matrix = np.broadcast_to( self.log_potential_matrix, (len(self.connected_var_keys),) + self.log_potential_matrix.shape[-2:], @@ -686,3 +717,43 @@ def _get_variables_to_factors( ] ) return variables_to_factors + + def flatten(self, data: np.ndarray) -> np.ndarray: + num_factors = len(self.factors) + if data.shape != (num_factors,) + self.log_potential_matrix.shape[ + -2: + ] and data.shape != (num_factors, np.sum(self.log_potential_matrix.shape[-2:])): + raise ValueError( + f"data should be of shape {(num_factors,)} or " + f"{(num_factors, np.sum(self.log_potential_matrix.shape[-2:]))}. " + f"Got {data.shape}." + ) + + return data.flatten() + + def unflatten(self, flat_data: np.ndarray) -> np.ndarray: + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + ) + + num_factors = len(self.factors) + if flat_data.size == num_factors * np.product( + self.log_potential_matrix.shape[-2:] + ): + data = flat_data.reshape( + (num_factors,) + self.log_potential_matrix.shape[-2:] + ).copy() + elif flat_data.size == num_factors * np.sum( + self.log_potential_matrix.shape[-2:] + ): + data = flat_data.reshape( + (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) + ).copy() + else: + raise ValueError( + f"flat_data should be compatible with shape {(num_factors,) + self.log_potential_matrix.shape[-2:]} " + f"or (num_factors, np.sum(self.log_potential_matrix.shape[-2:])). Got {flat_data.shape}." + ) + + return data From c824474566bc27c58f95114bb9813f912bb30d5e Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 13:22:18 -0700 Subject: [PATCH 26/56] Use flatten in log potentials updates --- pgmax/fg/graph.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 45ae344a..98ac4c2a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -328,16 +328,17 @@ def update_log_potentials( data = updates[key] if key in fg_state.named_factor_groups: factor_group = fg_state.named_factor_groups[key] - if data.shape != factor_group.factor_group_log_potentials.shape: + flat_data = factor_group.flatten(data) + if flat_data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " - f"for factor group {key}. Got {data.shape}." + f"for factor group {key}. Got incompatible data shape {data.shape}." ) start = fg_state.factor_group_to_potentials_starts[factor_group] - log_potentials = log_potentials.at[ - start : start + factor_group.factor_group_log_potentials.shape[0] - ].set(data) + log_potentials = log_potentials.at[start : start + flat_data.shape[0]].set( + flat_data + ) elif frozenset(key) in fg_state.variables_to_factors: factor = fg_state.variables_to_factors[frozenset(key)] if data.shape != factor.log_potentials.shape: From 70c726cdac959d7f6c68baf23a2e24904a683954 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 13:36:04 -0700 Subject: [PATCH 27/56] Simplify decode map states --- examples/ising_model.py | 11 +--- examples/rbm.py | 13 ++-- pgmax/fg/graph.py | 113 +++++++++++++++++++++++++++++++++- pgmax/fg/transforms.py | 130 ---------------------------------------- 4 files changed, 117 insertions(+), 150 deletions(-) delete mode 100644 pgmax/fg/transforms.py diff --git a/examples/ising_model.py b/examples/ising_model.py index 04f4d3db..13d73482 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -19,7 +19,7 @@ import matplotlib.pyplot as plt import numpy as np -from pgmax.fg import graph, groups, transforms +from pgmax.fg import graph, groups # %% [markdown] # ### Construct variable grid, initialize factor graph, and add factors @@ -46,7 +46,7 @@ # ### Run inference and visualize results # %% -run_bp, get_bp_state = transforms.BP(fg.bp_state, 3000) +run_bp, get_bp_state = graph.BP(fg.bp_state, 3000) # %% ftov_msgs = run_bp( @@ -55,12 +55,7 @@ bp_state = get_bp_state(ftov_msgs) # %% -decode_map_states = transforms.DecodeMAPStates(bp_state) -map_states = decode_map_states() -img = np.zeros((50, 50)) -for key in map_states: - img[key] = map_states[key] - +img = graph.decode_map_states(bp_state) fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) diff --git a/examples/rbm.py b/examples/rbm.py index 7e3d29e4..5a2caf03 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -20,7 +20,7 @@ import matplotlib.pyplot as plt import numpy as np -from pgmax.fg import graph, groups, transforms +from pgmax.fg import graph, groups # %% # Load parameters @@ -47,7 +47,7 @@ ) # %% -run_bp, get_bp_state = transforms.BP(fg.bp_state, 100) +run_bp, get_bp_state = graph.BP(fg.bp_state, 100) # %% # Run inference and decode @@ -63,14 +63,9 @@ } ) ) -decode_map_states = transforms.DecodeMAPStates(bp_state) -map_states = decode_map_states() +map_states = graph.decode_map_states(bp_state) # %% # Visualize decodings -img = np.zeros(bv.shape) -for ii in range(nv): - img[ii] = map_states[("visible", ii)] - -img = img.reshape((28, 28)) +img = map_states["visible"].reshape((28, 28)) plt.imshow(img) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 98ac4c2a..d6a19409 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -1,11 +1,10 @@ """A module containing the core class to specify a Factor Graph.""" - from __future__ import annotations import collections import copy import typing -from dataclasses import dataclass +from dataclasses import dataclass, replace from types import MappingProxyType from typing import ( Any, @@ -24,7 +23,8 @@ import jax.numpy as jnp import numpy as np -from pgmax.fg import fg_utils, groups, nodes +from pgmax.bp import infer +from pgmax.fg import fg_utils, graph, groups, nodes from pgmax.utils import cached_property @@ -667,3 +667,110 @@ def __setitem__( jax.device_put(self.value), {key: jax.device_put(data)}, self.fg_state ), ) + + +def BP(bp_state: graph.BPState, num_iters: int): + max_msg_size = int(jnp.max(bp_state.fg_state.wiring.edges_num_states)) + num_val_configs = ( + int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1 + ) + + @jax.jit + def run_bp( + log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None, + ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None, + evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None, + damping: float = 0.5, + ): + """Function to perform belief propagation. + + Specifically, belief propagation is run for num_iters iterations and + returns the resulting messages. + + Args: + num_iters: The number of iterations for which to perform message passing + damping: The damping factor to use for message updates between one timestep and the next + bp_state: Initial messages to start the belief propagation. + + Returns: + ftov messages after running BP for num_iters iterations + """ + # Retrieve the necessary data structures from the compiled self.wiring and + # convert these to jax arrays. + log_potentials = jax.device_put(bp_state.log_potentials.value) + if log_potentials_updates is not None: + log_potentials = graph.update_log_potentials( + log_potentials, log_potentials_updates, bp_state.fg_state + ) + + ftov_msgs = jax.device_put(bp_state.ftov_msgs.value) + if ftov_msgs_updates is not None: + ftov_msgs = graph.update_ftov_msgs( + ftov_msgs, ftov_msgs_updates, bp_state.fg_state + ) + + evidence = jax.device_put(bp_state.evidence.value) + if evidence_updates is not None: + evidence = graph.update_evidence( + evidence, evidence_updates, bp_state.fg_state + ) + + wiring = jax.device_put(bp_state.fg_state.wiring) + # Normalize the messages to ensure the maximum value is 0. + ftov_msgs = infer.normalize_and_clip_msgs( + ftov_msgs, wiring.edges_num_states, max_msg_size + ) + + def update(msgs, _): + # Compute new variable to factor messages by message passing + vtof_msgs = infer.pass_var_to_fac_messages( + msgs, + evidence, + wiring.var_states_for_edges, + ) + # Compute new factor to variable messages by message passing + ftov_msgs = infer.pass_fac_to_var_messages( + vtof_msgs, + wiring.factor_configs_edge_states, + log_potentials, + num_val_configs, + ) + # Use the results of message passing to perform damping and + # update the factor to variable messages + delta_msgs = ftov_msgs - msgs + msgs = msgs + (1 - damping) * delta_msgs + # Normalize and clip these damped, updated messages before returning + # them. + msgs = infer.normalize_and_clip_msgs( + msgs, + wiring.edges_num_states, + max_msg_size, + ) + return msgs, None + + ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters) + return ftov_msgs + + def get_bp_state(ftov_msgs): + return replace( + bp_state, + ftov_msgs=graph.FToVMessages( + fg_state=bp_state.ftov_msgs.fg_state, value=ftov_msgs + ), + ) + + return run_bp, get_bp_state + + +def decode_map_states( + bp_state: graph.BPState, +): + var_states_for_edges = jax.device_put(bp_state.fg_state.wiring.var_states_for_edges) + evidence = jax.device_put(bp_state.evidence.value) + map_states = jax.tree_util.tree_map( + lambda x: jnp.argmax(x, axis=-1), + bp_state.fg_state.variable_group.unflatten( + evidence.at[var_states_for_edges].add(bp_state.ftov_msgs.value) + ), + ) + return map_states diff --git a/pgmax/fg/transforms.py b/pgmax/fg/transforms.py deleted file mode 100644 index d91d0523..00000000 --- a/pgmax/fg/transforms.py +++ /dev/null @@ -1,130 +0,0 @@ -from dataclasses import replace -from typing import Any, Dict, Optional, Tuple, Union - -import jax -import jax.numpy as jnp - -from pgmax.bp import infer -from pgmax.fg import graph - - -def BP(bp_state: graph.BPState, num_iters: int): - max_msg_size = int(jnp.max(bp_state.fg_state.wiring.edges_num_states)) - num_val_configs = ( - int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1 - ) - - @jax.jit - def run_bp( - log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None, - ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None, - evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None, - damping: float = 0.5, - ): - """Function to perform belief propagation. - - Specifically, belief propagation is run for num_iters iterations and - returns the resulting messages. - - Args: - num_iters: The number of iterations for which to perform message passing - damping: The damping factor to use for message updates between one timestep and the next - bp_state: Initial messages to start the belief propagation. - - Returns: - ftov messages after running BP for num_iters iterations - """ - # Retrieve the necessary data structures from the compiled self.wiring and - # convert these to jax arrays. - log_potentials = jax.device_put(bp_state.log_potentials.value) - if log_potentials_updates is not None: - log_potentials = graph.update_log_potentials( - log_potentials, log_potentials_updates, bp_state.fg_state - ) - - ftov_msgs = jax.device_put(bp_state.ftov_msgs.value) - if ftov_msgs_updates is not None: - ftov_msgs = graph.update_ftov_msgs( - ftov_msgs, ftov_msgs_updates, bp_state.fg_state - ) - - evidence = jax.device_put(bp_state.evidence.value) - if evidence_updates is not None: - evidence = graph.update_evidence( - evidence, evidence_updates, bp_state.fg_state - ) - - wiring = jax.device_put(bp_state.fg_state.wiring) - # Normalize the messages to ensure the maximum value is 0. - ftov_msgs = infer.normalize_and_clip_msgs( - ftov_msgs, wiring.edges_num_states, max_msg_size - ) - - def update(msgs, _): - # Compute new variable to factor messages by message passing - vtof_msgs = infer.pass_var_to_fac_messages( - msgs, - evidence, - wiring.var_states_for_edges, - ) - # Compute new factor to variable messages by message passing - ftov_msgs = infer.pass_fac_to_var_messages( - vtof_msgs, - wiring.factor_configs_edge_states, - log_potentials, - num_val_configs, - ) - # Use the results of message passing to perform damping and - # update the factor to variable messages - delta_msgs = ftov_msgs - msgs - msgs = msgs + (1 - damping) * delta_msgs - # Normalize and clip these damped, updated messages before returning - # them. - msgs = infer.normalize_and_clip_msgs( - msgs, - wiring.edges_num_states, - max_msg_size, - ) - return msgs, None - - ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters) - return ftov_msgs - - def get_bp_state(ftov_msgs): - return replace( - bp_state, - ftov_msgs=graph.FToVMessages( - fg_state=bp_state.ftov_msgs.fg_state, value=ftov_msgs - ), - ) - - return run_bp, get_bp_state - - -def DecodeMAPStates(bp_state: graph.BPState): - def decode_map_states( - variable_name: Any = None, - ) -> Union[int, Dict[Tuple[Any, ...], int]]: - var_states_for_edges = jax.device_put( - bp_state.fg_state.wiring.var_states_for_edges - ) - evidence = jax.device_put(bp_state.evidence.value) - beliefs = evidence.at[var_states_for_edges].add(bp_state.ftov_msgs.value) - if variable_name is None: - variables_to_map_states: Dict[Tuple[Any, ...], int] = {} - for variable_name in bp_state.ftov_msgs.fg_state.variable_group.keys: - variable = bp_state.ftov_msgs.fg_state.variable_group[variable_name] - start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[variable] - variables_to_map_states[variable_name] = int( - jnp.argmax(beliefs[start_index : start_index + variable.num_states]) - ) - - return variables_to_map_states - else: - variable = bp_state.ftov_msgs.fg_state.variable_group[variable_name] - start_index = bp_state.ftov_msgs.fg_state.vars_to_starts[variable] - return int( - jnp.argmax(beliefs[start_index : start_index + variable.num_states]) - ) - - return decode_map_states From 0ca1a3de136a5bb8cbe04128ce4c35db71df60b5 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 14:19:12 -0700 Subject: [PATCH 28/56] Fix flatten/unflatten --- pgmax/fg/graph.py | 88 +++++++++++++++++++++++++++++++++------------- pgmax/fg/groups.py | 73 +++++++++++++++++++++++++++++++------- 2 files changed, 124 insertions(+), 37 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index d6a19409..030b0baa 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -4,7 +4,7 @@ import collections import copy import typing -from dataclasses import dataclass, replace +from dataclasses import asdict, dataclass from types import MappingProxyType from typing import ( Any, @@ -24,7 +24,7 @@ import numpy as np from pgmax.bp import infer -from pgmax.fg import fg_utils, graph, groups, nodes +from pgmax.fg import fg_utils, groups, nodes from pgmax.utils import cached_property @@ -669,11 +669,31 @@ def __setitem__( ) -def BP(bp_state: graph.BPState, num_iters: int): - max_msg_size = int(jnp.max(bp_state.fg_state.wiring.edges_num_states)) - num_val_configs = ( - int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1 - ) +@jax.tree_util.register_pytree_node_class +@dataclass(frozen=True, eq=False) +class BPArrays: + + log_potentials: Union[np.ndarray, jnp.ndarray] + ftov_msgs: Union[np.ndarray, jnp.ndarray] + evidence: Union[np.ndarray, jnp.ndarray] + + def __post_init__(self): + for field in self.__dataclass_fields__: + if isinstance(getattr(self, field), np.ndarray): + getattr(self, field).flags.writeable = False + + def tree_flatten(self): + return jax.tree_util.tree_flatten(asdict(self)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(**aux_data.unflatten(children)) + + +def BP(bp_state: BPState, num_iters: int): + wiring = jax.device_put(bp_state.fg_state.wiring) + max_msg_size = int(jnp.max(wiring.edges_num_states)) + num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 @jax.jit def run_bp( @@ -681,7 +701,7 @@ def run_bp( ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None, evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None, damping: float = 0.5, - ): + ) -> BPArrays: """Function to perform belief propagation. Specifically, belief propagation is run for num_iters iterations and @@ -699,23 +719,20 @@ def run_bp( # convert these to jax arrays. log_potentials = jax.device_put(bp_state.log_potentials.value) if log_potentials_updates is not None: - log_potentials = graph.update_log_potentials( + log_potentials = update_log_potentials( log_potentials, log_potentials_updates, bp_state.fg_state ) ftov_msgs = jax.device_put(bp_state.ftov_msgs.value) if ftov_msgs_updates is not None: - ftov_msgs = graph.update_ftov_msgs( + ftov_msgs = update_ftov_msgs( ftov_msgs, ftov_msgs_updates, bp_state.fg_state ) evidence = jax.device_put(bp_state.evidence.value) if evidence_updates is not None: - evidence = graph.update_evidence( - evidence, evidence_updates, bp_state.fg_state - ) + evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state) - wiring = jax.device_put(bp_state.fg_state.wiring) # Normalize the messages to ensure the maximum value is 0. ftov_msgs = infer.normalize_and_clip_msgs( ftov_msgs, wiring.edges_num_states, max_msg_size @@ -749,22 +766,45 @@ def update(msgs, _): return msgs, None ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters) - return ftov_msgs + return BPArrays( + log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence + ) + + def get_bp_state(bp_arrays: BPArrays) -> BPState: + return BPState( + log_potentials=LogPotentials( + fg_state=bp_state.fg_state, value=bp_arrays.log_potentials + ), + ftov_msgs=FToVMessages( + fg_state=bp_state.fg_state, + value=bp_arrays.ftov_msgs, + ), + evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence), + ) + + @jax.jit + def get_beliefs(bp_arrays: BPArrays): + evidence = jax.device_put(bp_arrays.evidence) + beliefs = bp_state.fg_state.variable_group.unflatten( + evidence.at[wiring.var_states_for_edges].add(bp_arrays.ftov_msgs) + ) + return beliefs - def get_bp_state(ftov_msgs): - return replace( - bp_state, - ftov_msgs=graph.FToVMessages( - fg_state=bp_state.ftov_msgs.fg_state, value=ftov_msgs + @jax.jit + def decode_map_states(bp_arrays: BPArrays): + evidence = jax.device_put(bp_arrays.evidence) + map_states = jax.tree_util.tree_map( + lambda x: jnp.argmax(x, axis=-1), + bp_state.fg_state.variable_group.unflatten( + evidence.at[wiring.var_states_for_edges].add(bp_arrays.ftov_msgs) ), ) + return map_states - return run_bp, get_bp_state + return run_bp, get_bp_state, get_beliefs, decode_map_states -def decode_map_states( - bp_state: graph.BPState, -): +def decode_map_states(bp_state: BPState): var_states_for_edges = jax.device_put(bp_state.fg_state.wiring.var_states_for_edges) evidence = jax.device_put(bp_state.evidence.value) map_states = jax.tree_util.tree_map( diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 26551ab3..24816f76 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -341,9 +341,11 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: ) if flat_data.size == np.product(self.shape): - data = flat_data.reshape(self.shape).copy() + data = np.array(flat_data.reshape(self.shape), copy=True) elif flat_data.size == np.product(self.shape) * self.variable_size: - data = flat_data.reshape(self.shape + (self.variable_size,)).copy() + data = np.array( + flat_data.reshape(self.shape + (self.variable_size,)), copy=True + ) else: raise ValueError( f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.variable_size,)}. " @@ -420,10 +422,12 @@ def unflatten( data = {} for key in self.variable_names: if use_num_states: - data[key] = flat_data[start : start + self.variable_size] + data[key] = np.array( + flat_data[start : start + self.variable_size], copy=True + ) start += self.variable_size else: - data[key] = flat_data[start] + data[key] = np.array(flat_data[start], copy=True) start += 1 return data @@ -610,10 +614,47 @@ def _get_variables_to_factors( return variables_to_factors def flatten(self, data: np.ndarray) -> np.ndarray: - pass + num_factors = len(self.factors) + if data.shape != (num_factors, self.factor_configs.shape[0]) and data.shape != ( + num_factors, + np.sum(self.factors[0].edges_num_states), + ): + raise ValueError( + f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " + f"{( num_factors, np.sum(self.factors[0].edges_num_states))}. " + f"Got {data.shape}." + ) + + return data.flatten() def unflatten(self, flat_data: np.ndarray) -> np.ndarray: - pass + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + ) + + num_factors = len(self.factors) + if flat_data.size == num_factors * self.factor_configs.shape[0]: + data = np.array( + flat_data.reshape( + (num_factors, self.factor_configs.shape[0]), + copy=True, + ) + ) + elif flat_data.size == num_factors * np.sum(self.factors[0].edges_num_states): + data = np.array( + flat_data.reshape( + (num_factors, np.sum(self.factors[0].edges_num_states)) + ), + copy=True, + ) + else: + raise ValueError( + f"flat_data should be compatible with shape {(num_factors, self.factor_configs.shape[0])} " + f"or (num_factors, np.sum(self.factors[0].edges_num_states)). Got {flat_data.shape}." + ) + + return data @dataclass(frozen=True, eq=False) @@ -724,7 +765,7 @@ def flatten(self, data: np.ndarray) -> np.ndarray: -2: ] and data.shape != (num_factors, np.sum(self.log_potential_matrix.shape[-2:])): raise ValueError( - f"data should be of shape {(num_factors,)} or " + f"data should be of shape {(num_factors,) + self.log_potential_matrix.shape[-2:]} or " f"{(num_factors, np.sum(self.log_potential_matrix.shape[-2:]))}. " f"Got {data.shape}." ) @@ -741,15 +782,21 @@ def unflatten(self, flat_data: np.ndarray) -> np.ndarray: if flat_data.size == num_factors * np.product( self.log_potential_matrix.shape[-2:] ): - data = flat_data.reshape( - (num_factors,) + self.log_potential_matrix.shape[-2:] - ).copy() + data = np.array( + flat_data.reshape( + (num_factors,) + self.log_potential_matrix.shape[-2:] + ), + copy=True, + ) elif flat_data.size == num_factors * np.sum( self.log_potential_matrix.shape[-2:] ): - data = flat_data.reshape( - (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) - ).copy() + data = np.array( + flat_data.reshape( + (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) + ), + copy=True, + ) else: raise ValueError( f"flat_data should be compatible with shape {(num_factors,) + self.log_potential_matrix.shape[-2:]} " From eb035cf403d8b931604140e7e062ff6045295ca6 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 14:21:35 -0700 Subject: [PATCH 29/56] Get rid of copy in unflatten --- pgmax/fg/groups.py | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 24816f76..6d1b067e 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -341,11 +341,9 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: ) if flat_data.size == np.product(self.shape): - data = np.array(flat_data.reshape(self.shape), copy=True) + data = flat_data.reshape(self.shape) elif flat_data.size == np.product(self.shape) * self.variable_size: - data = np.array( - flat_data.reshape(self.shape + (self.variable_size,)), copy=True - ) + data = flat_data.reshape(self.shape + (self.variable_size,)) else: raise ValueError( f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.variable_size,)}. " @@ -422,12 +420,10 @@ def unflatten( data = {} for key in self.variable_names: if use_num_states: - data[key] = np.array( - flat_data[start : start + self.variable_size], copy=True - ) + data[key] = flat_data[start : start + self.variable_size] start += self.variable_size else: - data[key] = np.array(flat_data[start], copy=True) + data[key] = flat_data[start] start += 1 return data @@ -635,18 +631,12 @@ def unflatten(self, flat_data: np.ndarray) -> np.ndarray: num_factors = len(self.factors) if flat_data.size == num_factors * self.factor_configs.shape[0]: - data = np.array( - flat_data.reshape( - (num_factors, self.factor_configs.shape[0]), - copy=True, - ) + data = flat_data.reshape( + (num_factors, self.factor_configs.shape[0]), ) elif flat_data.size == num_factors * np.sum(self.factors[0].edges_num_states): - data = np.array( - flat_data.reshape( - (num_factors, np.sum(self.factors[0].edges_num_states)) - ), - copy=True, + data = flat_data.reshape( + (num_factors, np.sum(self.factors[0].edges_num_states)) ) else: raise ValueError( @@ -782,20 +772,14 @@ def unflatten(self, flat_data: np.ndarray) -> np.ndarray: if flat_data.size == num_factors * np.product( self.log_potential_matrix.shape[-2:] ): - data = np.array( - flat_data.reshape( - (num_factors,) + self.log_potential_matrix.shape[-2:] - ), - copy=True, + data = flat_data.reshape( + (num_factors,) + self.log_potential_matrix.shape[-2:] ) elif flat_data.size == num_factors * np.sum( self.log_potential_matrix.shape[-2:] ): - data = np.array( - flat_data.reshape( - (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) - ), - copy=True, + data = flat_data.reshape( + (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) ) else: raise ValueError( From 98fae8f1571e7b8e05a15e5a71f03df9a01ee486 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 14:26:44 -0700 Subject: [PATCH 30/56] Update notebooks --- examples/ising_model.py | 8 ++++---- examples/rbm.py | 24 +++++++++++------------- pgmax/fg/graph.py | 12 ------------ 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index 13d73482..5322b3ac 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -46,16 +46,16 @@ # ### Run inference and visualize results # %% -run_bp, get_bp_state = graph.BP(fg.bp_state, 3000) +bp_state = fg.bp_state +run_bp, _, _, decode_map_states = graph.BP(bp_state, 3000) # %% -ftov_msgs = run_bp( +bp_arrays = run_bp( evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} ) -bp_state = get_bp_state(ftov_msgs) # %% -img = graph.decode_map_states(bp_state) +img = decode_map_states(bp_arrays) fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) diff --git a/examples/rbm.py b/examples/rbm.py index 5a2caf03..539d41b4 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -47,23 +47,21 @@ ) # %% -run_bp, get_bp_state = graph.BP(fg.bp_state, 100) +run_bp, _, _, decode_map_states = graph.BP(fg.bp_state, 100) # %% # Run inference and decode -bp_state = get_bp_state( - run_bp( - evidence_updates={ - "hidden": np.stack( - [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 - ), - "visible": np.stack( - [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 - ), - } - ) +bp_arrays = run_bp( + evidence_updates={ + "hidden": np.stack( + [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 + ), + "visible": np.stack( + [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 + ), + } ) -map_states = graph.decode_map_states(bp_state) +map_states = decode_map_states(bp_arrays) # %% # Visualize decodings diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 030b0baa..bf1dc0d8 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -802,15 +802,3 @@ def decode_map_states(bp_arrays: BPArrays): return map_states return run_bp, get_bp_state, get_beliefs, decode_map_states - - -def decode_map_states(bp_state: BPState): - var_states_for_edges = jax.device_put(bp_state.fg_state.wiring.var_states_for_edges) - evidence = jax.device_put(bp_state.evidence.value) - map_states = jax.tree_util.tree_map( - lambda x: jnp.argmax(x, axis=-1), - bp_state.fg_state.variable_group.unflatten( - evidence.at[var_states_for_edges].add(bp_state.ftov_msgs.value) - ), - ) - return map_states From 9af46d1ebca2e03c0b876e5e45bf2f89d91b6d54 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 14:54:27 -0700 Subject: [PATCH 31/56] Fix all notebooks --- examples/heretic_example.py | 48 ++++++++++++-------------------- examples/sanity_check_example.py | 19 ++++++------- pgmax/fg/groups.py | 45 +++++++++++++++++++----------- 3 files changed, 55 insertions(+), 57 deletions(-) diff --git a/examples/heretic_example.py b/examples/heretic_example.py index e9d64217..ce4c2629 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -14,15 +14,14 @@ # --- # %% +# %matplotlib inline +# Standard Package Imports +from dataclasses import replace from timeit import default_timer as timer from typing import Any, List, Tuple import jax import jax.numpy as jnp - -# %% -# %matplotlib inline -# Standard Package Imports import matplotlib.pyplot as plt import numpy as np @@ -122,12 +121,11 @@ def binary_connected_variables( log_potential_matrix=W_pot[:, :, k_row, k_col], ) + # %% [markdown] # # Construct Initial Messages # %% - - def custom_flatten_ordering(Mdown, Mup): flat_idx = 0 flat_Mdown = Mdown.flatten() @@ -177,30 +175,27 @@ def custom_flatten_ordering(Mdown, Mup): # %% tags=[] # Run BP -init_msgs = fg.get_init_msgs() -init_msgs.ftov = graph.FToVMessages( - factor_graph=fg, - init_value=jax.device_put( - custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup)) +bp_state = replace( + fg.bp_state, + ftov_msgs=graph.FToVMessages( + fg_state=fg.fg_state, + value=jax.device_put( + custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup)) + ), ), ) -init_msgs.evidence[0] = np.array(bXn_evidence) -init_msgs.evidence[1] = np.array(bHn_evidence) +bp_state.evidence[0] = np.array(bXn_evidence) +bp_state.evidence[1] = np.array(bHn_evidence) +run_bp, _, _, decode_map_states = graph.BP(bp_state, 500) bp_start_time = timer() # Assign evidence to pixel vars -final_msgs = fg.run_bp( - 500, - 0.5, - init_msgs=init_msgs, -) +bp_arrays = run_bp() bp_end_time = timer() print(f"time taken for bp {bp_end_time - bp_start_time}") # Run inference and convert result to human-readable data structure data_writeback_start_time = timer() -map_message_dict = fg.decode_map_states( - final_msgs, -) +map_states = decode_map_states(bp_arrays) data_writeback_end_time = timer() print( f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}" @@ -236,13 +231,6 @@ def plot_images(images): # %% -img_arr = np.zeros((1, im_size[0], im_size[1])) - -for row in range(im_size[0]): - for col in range(im_size[1]): - img_val = float(map_message_dict[0, row, col]) - if img_val == 2.0: - img_val = 0.4 - img_arr[0, row, col] = img_val * 1.0 - +img_arr = map_states[0][None].copy().astype(float) +img_arr[img_arr == 2.0] = 0.4 plot_images(img_arr) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index f8862efa..0d9ffa65 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -330,17 +330,18 @@ def create_valid_suppression_config_arr(suppression_diameter): # %% # Run BP # Set the evidence -init_msgs = fg.get_init_msgs() -init_msgs.evidence["grid_vars"] = grid_evidence_arr -init_msgs.evidence["additional_vars"] = additional_vars_evidence_dict +bp_state = fg.bp_state +bp_state.evidence["grid_vars"] = grid_evidence_arr +bp_state.evidence["additional_vars"] = additional_vars_evidence_dict +run_bp, _, _, decode_map_states = graph.BP(bp_state, 1000) bp_start_time = timer() -final_msgs = fg.run_bp(1000, 0.5, init_msgs=init_msgs) +bp_arrays = run_bp() bp_end_time = timer() print(f"time taken for bp {bp_end_time - bp_start_time}") # Run inference and convert result to human-readable data structure data_writeback_start_time = timer() -map_message_dict = fg.decode_map_states(final_msgs) +map_states = decode_map_states(bp_arrays) data_writeback_end_time = timer() print( f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}" @@ -358,13 +359,11 @@ def create_valid_suppression_config_arr(suppression_diameter): for row in range(M): for col in range(N): try: - bp_values[i, row, col] = map_message_dict["grid_vars", i, row, col] + bp_values[i, row, col] = map_states["grid_vars"][i, row, col] bu_evidence[i, row, col, :] = grid_evidence_arr[i, row, col] - except KeyError: + except IndexError: try: - bp_values[i, row, col] = map_message_dict[ - "additional_vars", i, row, col - ] + bp_values[i, row, col] = map_states["additional_vars"][i, row, col] bu_evidence[i, row, col, :] = additional_vars_evidence_dict[ (i, row, col) ] diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 6d1b067e..b2ec5b96 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -18,6 +18,7 @@ Union, ) +import jax import jax.numpy as jnp import numpy as np @@ -121,7 +122,7 @@ def container_keys(self) -> Tuple: """ return (None,) - def flatten(self, data: Any) -> np.ndarray: + def flatten(self, data: Any) -> jnp.ndarray: raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) @@ -222,8 +223,8 @@ def _get_keys_to_vars(self) -> OrderedDict[Hashable, nodes.Variable]: return keys_to_vars - def flatten(self, data: Union[Mapping, Sequence]) -> np.ndarray: - flat_data = np.concatenate( + def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: + flat_data = jnp.concatenate( [ self.variable_group_container[key].flatten(data[key]) for key in self.container_keys @@ -323,7 +324,7 @@ def _get_keys_to_vars( return keys_to_vars - def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: if data.shape != self.shape and data.shape != self.shape + ( self.variable_size, ): @@ -332,9 +333,11 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: f"Got {data.shape}." ) - return data.flatten() + return jax.device_put(data).flatten() - def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." @@ -380,7 +383,9 @@ def _get_keys_to_vars(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: return keys_to_vars - def flatten(self, data: Mapping[Hashable, np.ndarray]) -> np.ndarray: + def flatten( + self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] + ) -> jnp.ndarray: for key in data: if key not in self._keys_to_vars: raise ValueError(f"data is referring to a non-existent variable {key}.") @@ -392,12 +397,12 @@ def flatten(self, data: Mapping[Hashable, np.ndarray]) -> np.ndarray: f"Got {data[key].shape}." ) - flat_data = np.concatenate([data[key].flatten() for key in self.keys]) + flat_data = jnp.concatenate([data[key].flatten() for key in self.keys]) return flat_data def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] - ) -> Dict[Hashable, np.ndarray]: + ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." @@ -531,12 +536,14 @@ def factor_num_states(self) -> np.ndarray: ) return factor_num_states - def flatten(self, data: np.ndarray) -> np.ndarray: + def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: raise NotImplementedError( "Please subclass the FactorGroup class and override this method" ) - def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> np.ndarray: + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Union[np.ndarray, jnp.ndarray]: raise NotImplementedError( "Please subclass the FactorGroup class and override this method" ) @@ -609,7 +616,7 @@ def _get_variables_to_factors( ) return variables_to_factors - def flatten(self, data: np.ndarray) -> np.ndarray: + def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: num_factors = len(self.factors) if data.shape != (num_factors, self.factor_configs.shape[0]) and data.shape != ( num_factors, @@ -621,9 +628,11 @@ def flatten(self, data: np.ndarray) -> np.ndarray: f"Got {data.shape}." ) - return data.flatten() + return jax.device_put(data).flatten() - def unflatten(self, flat_data: np.ndarray) -> np.ndarray: + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." @@ -749,7 +758,7 @@ def _get_variables_to_factors( ) return variables_to_factors - def flatten(self, data: np.ndarray) -> np.ndarray: + def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: num_factors = len(self.factors) if data.shape != (num_factors,) + self.log_potential_matrix.shape[ -2: @@ -760,9 +769,11 @@ def flatten(self, data: np.ndarray) -> np.ndarray: f"Got {data.shape}." ) - return data.flatten() + return jax.device_put(data).flatten() - def unflatten(self, flat_data: np.ndarray) -> np.ndarray: + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." From b1191e6c2ab59531bbb123f616283d4c55b0bee8 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 15:55:12 -0700 Subject: [PATCH 32/56] Add examples for batching and gradients --- examples/ising_model.py | 28 ++++++++++++++++++++++++++- examples/rbm.py | 30 ++++++++++++++++++++++------- pgmax/fg/groups.py | 42 +++++++++++++++++++++++++++++------------ 3 files changed, 80 insertions(+), 20 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index 5322b3ac..0ea82493 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -16,6 +16,7 @@ # %% # %matplotlib inline import jax +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np @@ -47,7 +48,7 @@ # %% bp_state = fg.bp_state -run_bp, _, _, decode_map_states = graph.BP(bp_state, 3000) +run_bp, _, get_beliefs, decode_map_states = graph.BP(bp_state, 3000) # %% bp_arrays = run_bp( @@ -59,6 +60,31 @@ fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) + +# %% [markdown] +# ### Gradients and batching + +# %% +def loss(log_potentials_updates, evidence_updates): + bp_arrays = run_bp( + log_potentials_updates=log_potentials_updates, evidence_updates=evidence_updates + ) + beliefs = get_beliefs(bp_arrays) + loss = -jnp.sum(beliefs) + return loss + + +batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0)) +log_potentials_grads = jax.jit(jax.grad(loss, argnums=0)) + +# %% +batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))}) + +# %% +grads = log_potentials_grads( + {"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} +) + # %% [markdown] # ### Message and evidence manipulation diff --git a/examples/rbm.py b/examples/rbm.py index 539d41b4..19cf4076 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -17,6 +17,7 @@ # %matplotlib inline import itertools +import jax import matplotlib.pyplot as plt import numpy as np @@ -50,20 +51,35 @@ run_bp, _, _, decode_map_states = graph.BP(fg.bp_state, 100) # %% -# Run inference and decode -bp_arrays = run_bp( +# Run inference and decode using vmap +n_samples = 16 +bp_arrays = jax.vmap(run_bp, in_axes=0, out_axes=0)( evidence_updates={ "hidden": np.stack( - [np.zeros_like(bh), bh + np.random.logistic(size=bh.shape)], axis=1 + [ + np.zeros((n_samples,) + bh.shape), + bh + np.random.logistic(size=(n_samples,) + bh.shape), + ], + axis=-1, ), "visible": np.stack( - [np.zeros_like(bv), bv + np.random.logistic(size=bv.shape)], axis=1 + [ + np.zeros((n_samples,) + bv.shape), + bv + np.random.logistic(size=(n_samples,) + bv.shape), + ], + axis=-1, ), } ) -map_states = decode_map_states(bp_arrays) +map_states = jax.vmap(decode_map_states, in_axes=0, out_axes=0)(bp_arrays) # %% # Visualize decodings -img = map_states["visible"].reshape((28, 28)) -plt.imshow(img) +fig, ax = plt.subplots(4, 4, figsize=(10, 10)) +for ii in range(16): + ax[np.unravel_index(ii, (4, 4))].imshow( + map_states["visible"][ii].copy().reshape((28, 28)) + ) + ax[np.unravel_index(ii, (4, 4))].axis("off") + +fig.tight_layout() diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index b2ec5b96..d057be3b 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -618,17 +618,27 @@ def _get_variables_to_factors( def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: num_factors = len(self.factors) - if data.shape != (num_factors, self.factor_configs.shape[0]) and data.shape != ( - num_factors, - np.sum(self.factors[0].edges_num_states), + if ( + data.shape != (num_factors, self.factor_configs.shape[0]) + and data.shape + != ( + num_factors, + np.sum(self.factors[0].edges_num_states), + ) + and data.shape != (self.factor_configs.shape[0],) ): raise ValueError( f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " - f"{( num_factors, np.sum(self.factors[0].edges_num_states))}. " - f"Got {data.shape}." + f"{( num_factors, np.sum(self.factors[0].edges_num_states))} or " + f"(self.factor_configs.shape[0],) . Got {data.shape}." ) - return jax.device_put(data).flatten() + if data.shape == (self.factor_configs.shape[0],): + flat_data = jnp.tile(data, num_factors) + else: + flat_data = jax.device_put(data).flatten() + + return flat_data def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] @@ -760,16 +770,24 @@ def _get_variables_to_factors( def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: num_factors = len(self.factors) - if data.shape != (num_factors,) + self.log_potential_matrix.shape[ - -2: - ] and data.shape != (num_factors, np.sum(self.log_potential_matrix.shape[-2:])): + if ( + data.shape != (num_factors,) + self.log_potential_matrix.shape[-2:] + and data.shape + != (num_factors, np.sum(self.log_potential_matrix.shape[-2:])) + and data.shape != self.log_potential_matrix.shape[-2:] + ): raise ValueError( f"data should be of shape {(num_factors,) + self.log_potential_matrix.shape[-2:]} or " - f"{(num_factors, np.sum(self.log_potential_matrix.shape[-2:]))}. " - f"Got {data.shape}." + f"{(num_factors, np.sum(self.log_potential_matrix.shape[-2:]))} or " + f"{self.log_potential_matrix.shape[-2:]}. Got {data.shape}." ) - return jax.device_put(data).flatten() + if data.shape == self.log_potential_matrix.shape[-2:]: + flat_data = jnp.tile(jax.device_put(data).flatten(), num_factors) + else: + flat_data = jax.device_put(data).flatten() + + return flat_data def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] From 33b9899b69ce89b83ea6eaa75297bc3f5056426a Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 20:33:30 -0700 Subject: [PATCH 33/56] Separate out decode_map_states --- examples/heretic_example.py | 4 ++-- examples/ising_model.py | 4 ++-- examples/rbm.py | 6 ++++-- examples/sanity_check_example.py | 4 ++-- pgmax/fg/graph.py | 20 +++++++++----------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/heretic_example.py b/examples/heretic_example.py index ce4c2629..2208cbb1 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -186,7 +186,7 @@ def custom_flatten_ordering(Mdown, Mup): ) bp_state.evidence[0] = np.array(bXn_evidence) bp_state.evidence[1] = np.array(bHn_evidence) -run_bp, _, _, decode_map_states = graph.BP(bp_state, 500) +run_bp, _, get_beliefs = graph.BP(bp_state, 500) bp_start_time = timer() # Assign evidence to pixel vars bp_arrays = run_bp() @@ -195,7 +195,7 @@ def custom_flatten_ordering(Mdown, Mup): # Run inference and convert result to human-readable data structure data_writeback_start_time = timer() -map_states = decode_map_states(bp_arrays) +map_states = graph.decode_map_states(get_beliefs(bp_arrays)) data_writeback_end_time = timer() print( f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}" diff --git a/examples/ising_model.py b/examples/ising_model.py index 0ea82493..ee702b9b 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -48,7 +48,7 @@ # %% bp_state = fg.bp_state -run_bp, _, get_beliefs, decode_map_states = graph.BP(bp_state, 3000) +run_bp, _, get_beliefs = graph.BP(bp_state, 3000) # %% bp_arrays = run_bp( @@ -56,7 +56,7 @@ ) # %% -img = decode_map_states(bp_arrays) +img = graph.decode_map_states(get_beliefs(bp_arrays)) fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) diff --git a/examples/rbm.py b/examples/rbm.py index 19cf4076..46158429 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -48,7 +48,7 @@ ) # %% -run_bp, _, _, decode_map_states = graph.BP(fg.bp_state, 100) +run_bp, _, get_beliefs = graph.BP(fg.bp_state, 100) # %% # Run inference and decode using vmap @@ -71,7 +71,9 @@ ), } ) -map_states = jax.vmap(decode_map_states, in_axes=0, out_axes=0)(bp_arrays) +map_states = graph.decode_map_states( + jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays) +) # %% # Visualize decodings diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index 0d9ffa65..59c676c8 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -333,7 +333,7 @@ def create_valid_suppression_config_arr(suppression_diameter): bp_state = fg.bp_state bp_state.evidence["grid_vars"] = grid_evidence_arr bp_state.evidence["additional_vars"] = additional_vars_evidence_dict -run_bp, _, _, decode_map_states = graph.BP(bp_state, 1000) +run_bp, _, get_beliefs = graph.BP(bp_state, 1000) bp_start_time = timer() bp_arrays = run_bp() bp_end_time = timer() @@ -341,7 +341,7 @@ def create_valid_suppression_config_arr(suppression_diameter): # Run inference and convert result to human-readable data structure data_writeback_start_time = timer() -map_states = decode_map_states(bp_arrays) +map_states = graph.decode_map_states(get_beliefs(bp_arrays)) data_writeback_end_time = timer() print( f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}" diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index bf1dc0d8..6abe176c 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -790,15 +790,13 @@ def get_beliefs(bp_arrays: BPArrays): ) return beliefs - @jax.jit - def decode_map_states(bp_arrays: BPArrays): - evidence = jax.device_put(bp_arrays.evidence) - map_states = jax.tree_util.tree_map( - lambda x: jnp.argmax(x, axis=-1), - bp_state.fg_state.variable_group.unflatten( - evidence.at[wiring.var_states_for_edges].add(bp_arrays.ftov_msgs) - ), - ) - return map_states + return run_bp, get_bp_state, get_beliefs + - return run_bp, get_bp_state, get_beliefs, decode_map_states +@jax.jit +def decode_map_states(beliefs: Any): + map_states = jax.tree_util.tree_map( + lambda x: jnp.argmax(x, axis=-1), + beliefs, + ) + return map_states From 7f13dc6d87a464fe51df619c1d5587b61e6df1b7 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 20:59:24 -0700 Subject: [PATCH 34/56] Make test_pgmax pass --- tests/test_pgmax.py | 64 ++++++++++++--------------------------------- 1 file changed, 16 insertions(+), 48 deletions(-) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index eddf1496..0d9e0a2f 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -3,7 +3,6 @@ import jax import jax.numpy as jnp import numpy as np -import pytest from numpy.random import default_rng from scipy.ndimage import gaussian_filter @@ -309,8 +308,6 @@ def create_valid_suppression_config_arr(suppression_diameter): name=(row, col), ) - assert fg.get_factor((0, 0))[1] == 0 - # Create an EnumerationFactorGroup for vertical suppression factors vert_suppression_keys: List[List[Tuple[Any, ...]]] = [] for col in range(N): @@ -365,16 +362,16 @@ def create_valid_suppression_config_arr(suppression_diameter): # Run BP # Set the evidence - init_msgs = fg.get_init_msgs() - init_msgs.evidence["grid_vars"] = grid_evidence_arr - init_msgs.evidence["additional_vars"] = additional_vars_evidence_dict - fg.run_bp(1, 0.5) - one_step_msgs = fg.run_bp(1, 0.5, init_msgs=init_msgs) - final_msgs = fg.run_bp(99, 0.5, one_step_msgs) - + bp_state = fg.bp_state + bp_state.evidence["grid_vars"] = grid_evidence_arr + bp_state.evidence["additional_vars"] = additional_vars_evidence_dict + run_bp, _, get_beliefs = graph.BP(bp_state, 100) + bp_arrays = run_bp() # Test that the output messages are close to the true messages - assert jnp.allclose(final_msgs.ftov.value, true_final_msgs_output, atol=1e-06) - assert fg.decode_map_states(final_msgs) == true_map_state_output + assert jnp.allclose(bp_arrays.ftov_msgs, true_final_msgs_output, atol=1e-06) + decoded_map_states = graph.decode_map_states(get_beliefs(bp_arrays)) + for key in true_map_state_output: + assert true_map_state_output[key] == decoded_map_states[key[0]][key[1:]] def test_e2e_heretic(): @@ -389,7 +386,7 @@ def test_e2e_heretic(): bXn = np.zeros((30, 30, 3)) # Create the factor graph - fg = graph.FactorGraph((pixel_vars, hidden_vars), evidence_default_mode="random") + fg = graph.FactorGraph((pixel_vars, hidden_vars)) def binary_connected_variables( num_hidden_rows, num_hidden_cols, kernel_row, kernel_col @@ -416,39 +413,10 @@ def binary_connected_variables( ) # Assign evidence to pixel vars - init_msgs = fg.get_init_msgs() - init_msgs.evidence[0] = np.array(bXn) - init_msgs.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0]) - init_msgs.evidence[0, 0, 0] - init_msgs.evidence[1, 0, 0] - with pytest.raises(ValueError) as verror: - fg.get_factor((0, 0)) - - assert "Invalid factor key" in str(verror.value) - with pytest.raises(ValueError) as verror: - fg.get_factor((((0, 0), 0), (10, 20, 30))) - - assert "Invalid factor key" in str(verror.value) - assert isinstance(init_msgs.evidence.value, jnp.ndarray) + bp_state = fg.bp_state + bp_state.evidence[0] = np.array(bXn) + bp_state.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0]) + bp_state.evidence[0, 0, 0] + bp_state.evidence[1, 0, 0] + assert isinstance(bp_state.evidence.value, jnp.ndarray) assert len(fg.factors) == 7056 - evidence = graph.Evidence(factor_graph=fg) - for ftov_msgs in [ - graph.FToVMessages(factor_graph=fg), - graph.FToVMessages(factor_graph=fg, default_mode="random"), - ]: - ftov_msgs[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (0, 0, 0)] - ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] = np.ones(17) - assert np.all( - ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] == 1.0 - ) - ftov_msgs[1, 0, 0] = np.ones(17) - assert np.all( - ftov_msgs[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (1, 0, 0)] == 1.0 / 9 - ) - assert np.all( - ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] == 1.0 / 9 - ) - msgs = fg.run_bp( - 1, 0.5, init_msgs=graph.Messages(ftov=ftov_msgs, evidence=evidence) - ) - msgs.ftov[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (0, 0, 0)] From 4fe7a1b6acc401f85587b5e9f0bd62caab99767d Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 22:08:04 -0700 Subject: [PATCH 35/56] New test groups --- pgmax/fg/groups.py | 22 ++-- tests/fg/test_groups.py | 221 ++++++++++++++++++++++++++++++---------- 2 files changed, 177 insertions(+), 66 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index d057be3b..4c6bf69a 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -237,7 +237,7 @@ def unflatten( ) -> Union[Mapping, Sequence]: if flat_data.ndim != 1: raise ValueError( - f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) num_variables = 0 @@ -340,7 +340,7 @@ def unflatten( ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( - f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) if flat_data.size == np.product(self.shape): @@ -405,7 +405,7 @@ def unflatten( ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: if flat_data.ndim != 1: raise ValueError( - f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) num_variables = len(self.variable_names) @@ -479,7 +479,7 @@ def __getitem__( variables = frozenset(variables) if variables not in self._variables_to_factors: raise ValueError( - f"The queried factor {variables} is not present in the factor group" + f"The queried factor {variables} is not present in the factor group." ) return self._variables_to_factors[variables] @@ -541,9 +541,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: "Please subclass the FactorGroup class and override this method" ) - def unflatten( - self, flat_data: Union[np.ndarray, jnp.ndarray] - ) -> Union[np.ndarray, jnp.ndarray]: + def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: raise NotImplementedError( "Please subclass the FactorGroup class and override this method" ) @@ -629,8 +627,8 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: ): raise ValueError( f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " - f"{( num_factors, np.sum(self.factors[0].edges_num_states))} or " - f"(self.factor_configs.shape[0],) . Got {data.shape}." + f"{(num_factors, np.sum(self.factors[0].edges_num_states))} or " + f"(self.factor_configs.shape[0],). Got {data.shape}." ) if data.shape == (self.factor_configs.shape[0],): @@ -645,7 +643,7 @@ def unflatten( ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( - f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) num_factors = len(self.factors) @@ -722,7 +720,7 @@ def _get_variables_to_factors( if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" - f" more or less than 2 variables ({fac_list})." + f" {len(fac_list)} variables ({fac_list})." ) if not ( @@ -794,7 +792,7 @@ def unflatten( ) -> Union[np.ndarray, jnp.ndarray]: if flat_data.ndim != 1: raise ValueError( - f"Can only unflatten 1D array. Got an {flat_data.ndim}D array." + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) num_factors = len(self.factors) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index b2d56d71..951bfd80 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -1,76 +1,189 @@ +import jax +import jax.numpy as jnp import numpy as np import pytest from pgmax.fg import groups, nodes -def test_vargroup_list_idx(): - v_group = groups.VariableDict(15, tuple([0, 1, 2])) - assert v_group[[0, 1, 2]][0].num_states == 15 +def test_composite_variable_group(): + variable_dict1 = groups.VariableDict(15, tuple([0, 1, 2])) + variable_dict2 = groups.VariableDict(15, tuple([0, 1, 2])) + composite_variable_sequence = groups.CompositeVariableGroup( + [variable_dict1, variable_dict2] + ) + composite_variable_dict = groups.CompositeVariableGroup( + {(0, 1): variable_dict1, (2, 3): variable_dict2} + ) + with pytest.raises(ValueError, match="The key needs to have at least 2 elements"): + composite_variable_sequence[(0,)] + assert composite_variable_sequence[0, 1] == variable_dict1[1] + assert ( + composite_variable_sequence[[(0, 1), (1, 2)]] + == composite_variable_dict[[(0, 1, 1), (2, 3, 2)]] + ) + assert composite_variable_dict[0, 1, 0] == variable_dict1[0] + assert composite_variable_dict[[(0, 1, 1), (2, 3, 2)]] == [ + variable_dict1[1], + variable_dict2[2], + ] + assert jnp.all( + composite_variable_sequence.flatten( + [{key: np.zeros(15) for key in range(3)} for _ in range(2)] + ) + == composite_variable_dict.flatten( + { + (0, 1): {key: np.zeros(15) for key in range(3)}, + (2, 3): {key: np.zeros(15) for key in range(3)}, + } + ) + ) + assert jnp.all( + jax.tree_util.tree_leaves( + jax.tree_util.tree_multimap( + lambda x, y: jnp.all(x == y), + composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)), + [{key: jnp.zeros(15) for key in range(3)} for _ in range(2)], + ) + ) + ) + assert jnp.all( + jax.tree_util.tree_leaves( + jax.tree_util.tree_multimap( + lambda x, y: jnp.all(x == y), + composite_variable_dict.unflatten(jnp.zeros(15 * 3 * 2)), + { + (0, 1): {key: np.zeros(15) for key in range(3)}, + (2, 3): {key: np.zeros(15) for key in range(3)}, + }, + ) + ) + ) -def test_composite_vargroup_valueerror(): - v_group1 = groups.VariableDict(15, tuple([0, 1, 2])) - v_group2 = groups.VariableDict(15, tuple([0, 1, 2])) - comp_var_group = groups.CompositeVariableGroup(tuple([v_group1, v_group2])) - with pytest.raises(ValueError) as verror: - comp_var_group[tuple([0])] - assert "The key needs" in str(verror.value) +def test_nd_variable_array(): + variable_group = groups.NDVariableArray(2, (1,)) + assert isinstance(variable_group[0], nodes.Variable) + variable_group = groups.NDVariableArray(3, (2, 2)) + with pytest.raises( + ValueError, match="data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)." + ): + variable_group.flatten(np.zeros((3, 3))) -def test_composite_vargroup_evidence(): - v_group1 = groups.VariableDict(3, tuple([0, 1, 2])) - v_group1.container_keys - v_group2 = groups.VariableDict(3, tuple([0, 1, 2])) - comp_var_group = groups.CompositeVariableGroup(tuple([v_group1, v_group2])) - vars_to_evidence = comp_var_group.get_vars_to_evidence( - [{0: np.zeros(3)}, {0: np.zeros(3)}] - ) - assert set(vars_to_evidence.keys()) == set( - [comp_var_group[0, 0], comp_var_group[1, 0]] + assert jnp.all( + variable_group.flatten(np.array([[1, 2], [3, 4]])) == jnp.array([1, 2, 3, 4]) ) - for arr in vars_to_evidence.values(): - assert (arr == np.zeros(3, dtype=float)).all() - + with pytest.rasies(ValueError, "Can only unflatten 1D array. Got a 2D array."): + variable_group.unflatten(np.zeros((10, 20))) + + with pytest.raises( + ValueError, + "flat_data should be compatible with shape (2, 2) or (2, 2, 3). Got (10,).", + ): + variable_group.unflatten(np.zeros((10,))) + + assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2))) + assert jnp.all(variable_group.unflatten(np.zeros(12)) == jnp.zeros((2, 2, 3))) + + +def test_enumeration_factor_group(): + variable_group = groups.NDVariableArray(3, (2, 2)) + with pytest.raises( + "ValueError", match="Expected log potentials shape: (1,) or (2, 1). Got (3, 2)" + ): + enumeration_factor_group = groups.EnumerationFactorGroup( + variable_group=variable_group, + connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + factor_configs=np.zeros((1, 3)), + log_potentials=np.zeros((3, 2)), + ) -def test_1dvararray_indexing(): - v_group = groups.NDVariableArray(2, (1,)) - assert isinstance(v_group[0], nodes.Variable) + enumeration_factor_group = groups.EnumerationFactorGroup( + variable_group=variable_group, + connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + factor_configs=np.zeros((1, 3)), + ) + key = [(0, 0), (1, 1)] + with pytest.raises( + ValueError, + match=f"The queried factor {frozenset(key)} is not present in the factor group.", + ): + enumeration_factor_group[key] + + assert ( + enumeration_factor_group[[(0, 1), (1, 0), (1, 1)]] + == enumeration_factor_group.factors[1] + ) + with pytest.raises( + ValueError, "data should be of shape (2, 1) or (2, 9) or (1,). Got (4, 5)." + ): + enumeration_factor_group.factorslatten(np.zeros((4, 5))) + + assert jnp.all(enumeration_factor_group.flatten(np.ones(1)) == jnp.ones(2)) + assert jnp.all(enumeration_factor_group.flatten(np.ones(2, 9)) == jnp.ones(18)) + with pytest.raises( + ValueError, match="Can only unflatten 1D array. Got a 3D array." + ): + enumeration_factor_group.unflatten(jnp.ones((1, 2, 3))) + + with pytest.raises( + ValueError, + match="flat_data should be compatible with shape (2, 1) or (2, 9). Got (30,)", + ): + enumeration_factor_group.unflatten(jnp.zeros(30)) + + assert jnp.all( + enumeration_factor_group.unflatten(jnp.arange(2)) == jnp.array([[0], [1]]) + ) + assert jnp.all(enumeration_factor_group.unflatten(jnp.ones(18)) == jnp.ones((2, 9))) -def test_ndvararray_evidence_error(): - v_group = groups.NDVariableArray(3, (2, 2)) - with pytest.raises(ValueError) as verror: - v_group.get_vars_to_evidence(np.zeros((1, 1))) - assert "Input evidence" in str(verror.value) +def test_pairwise_factor_group(): + variable_group = groups.NDVariableArray(3, (2, 2)) + with pytest.raises( + ValueError, match="log_potential_matrix should be either a 2D array" + ): + groups.PairwiseFactorGroup( + variable_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float) + ) + with pytest.raises( + ValueError, + match="Expected log_potential_matrix for 1 factors. Got log_potential_matrix for 2 factors.", + ): + groups.PairwiseFactorGroup( + variable_group, [[(0, 0), (1, 1)]], np.zeros((2, 3, 3), dtype=float) + ) -def test_pairwisefacgroup_errors(): - v_group = groups.NDVariableArray(3, (2, 2)) - with pytest.raises(ValueError) as verror: + with pytest.raises( + ValueError, + match="All pairwise factors should connect to exactly 2 variables. Got a factor connecting to 3 variables.", + ): groups.PairwiseFactorGroup( - v_group, [[(0, 0), (1, 1), (0, 1)]], np.zeros((1,), dtype=float) + variable_group, [[(0, 0), (1, 1), (0, 1)]], np.zeros((3, 3), dtype=float) ) - assert "All pairwise factors" in str(verror.value) - with pytest.raises(ValueError) as verror: + with pytest.raises( + ValueError, + match="The specified pairwise factor [(0, 0), (1, 1)].", + ): groups.PairwiseFactorGroup( - v_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float) + variable_group, [[(0, 0), (1, 1)]], np.zeros((4, 4), dtype=float) ) - assert "self.log_potential_matrix must" in str(verror.value) - factor_group = groups.PairwiseFactorGroup( - v_group, {0: [(0, 0), (1, 1)]}, np.zeros((3, 3), dtype=float) + + pairwise_factor_group = groups.PairwiseFactorGroup( + variable_group, + [[(0, 0), (1, 1)], [(1, 0), (0, 1)]], + np.zeros((3, 3), dtype=float), + ) + with pytest.raises( + ValueError, + match="data should be of shape (2, 3, 3) or (2, 6) or (3, 3). Got (4, 4).", + ): + pairwise_factor_group.flatten(np.zeros((4, 4))) + + assert jnp.all( + pairwise_factor_group.flatten(np.zeros((3, 3))) == jnp.zeros(2 * 3 * 3) ) - with pytest.raises(ValueError) as verror: - factor_group[1] - assert "The queried factor" in str(verror.value) - - -def test_generic_evidence_errors(): - v_group = groups.VariableDict(3, tuple([0])) - with pytest.raises(ValueError) as verror: - v_group.get_vars_to_evidence({1: np.zeros((1, 1))}) - assert "The evidence is referring" in str(verror.value) - with pytest.raises(ValueError) as verror: - v_group.get_vars_to_evidence({0: np.zeros((1, 1))}) - assert "expects an evidence array" in str(verror.value) + assert jnp.all(pairwise_factor_group.flatten(np.zeros((2, 6))) == jnp.zeros(12)) From 3d8698d4dbf1247388b0182bf90b0a50fdf8d6b1 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 22:52:51 -0700 Subject: [PATCH 36/56] Fix test groups --- pgmax/fg/groups.py | 8 ++-- tests/fg/test_groups.py | 91 ++++++++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 4c6bf69a..468b1fa9 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -628,7 +628,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: raise ValueError( f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " f"{(num_factors, np.sum(self.factors[0].edges_num_states))} or " - f"(self.factor_configs.shape[0],). Got {data.shape}." + f"{(self.factor_configs.shape[0],)}. Got {data.shape}." ) if data.shape == (self.factor_configs.shape[0],): @@ -658,7 +658,7 @@ def unflatten( else: raise ValueError( f"flat_data should be compatible with shape {(num_factors, self.factor_configs.shape[0])} " - f"or (num_factors, np.sum(self.factors[0].edges_num_states)). Got {flat_data.shape}." + f"or {(num_factors, np.sum(self.factors[0].edges_num_states))}. Got {flat_data.shape}." ) return data @@ -733,8 +733,8 @@ def _get_variables_to_factors( raise ValueError( f"The specified pairwise factor {fac_list} (with " f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " - "configurations) does not match the specified log_potential_matrix " - "(with {self.log_potential_matrix.shape[-2:]} configurations)." + f"configurations) does not match the specified log_potential_matrix " + f"(with {self.log_potential_matrix.shape[-2:]} configurations)." ) factor_configs = ( diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 951bfd80..bf1d675a 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -1,3 +1,5 @@ +import re + import jax import jax.numpy as jnp import numpy as np @@ -21,10 +23,10 @@ def test_composite_variable_group(): assert composite_variable_sequence[0, 1] == variable_dict1[1] assert ( composite_variable_sequence[[(0, 1), (1, 2)]] - == composite_variable_dict[[(0, 1, 1), (2, 3, 2)]] + == composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]] ) - assert composite_variable_dict[0, 1, 0] == variable_dict1[0] - assert composite_variable_dict[[(0, 1, 1), (2, 3, 2)]] == [ + assert composite_variable_dict[(0, 1), 0] == variable_dict1[0] + assert composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]] == [ variable_dict1[1], variable_dict2[2], ] @@ -40,23 +42,27 @@ def test_composite_variable_group(): ) ) assert jnp.all( - jax.tree_util.tree_leaves( - jax.tree_util.tree_multimap( - lambda x, y: jnp.all(x == y), - composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)), - [{key: jnp.zeros(15) for key in range(3)} for _ in range(2)], + jnp.array( + jax.tree_util.tree_leaves( + jax.tree_util.tree_multimap( + lambda x, y: jnp.all(x == y), + composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)), + [{key: jnp.zeros(15) for key in range(3)} for _ in range(2)], + ) ) ) ) assert jnp.all( - jax.tree_util.tree_leaves( - jax.tree_util.tree_multimap( - lambda x, y: jnp.all(x == y), - composite_variable_dict.unflatten(jnp.zeros(15 * 3 * 2)), - { - (0, 1): {key: np.zeros(15) for key in range(3)}, - (2, 3): {key: np.zeros(15) for key in range(3)}, - }, + jnp.array( + jax.tree_util.tree_leaves( + jax.tree_util.tree_multimap( + lambda x, y: jnp.all(x == y), + composite_variable_dict.unflatten(jnp.zeros(15 * 3 * 2)), + { + (0, 1): {key: np.zeros(15) for key in range(3)}, + (2, 3): {key: np.zeros(15) for key in range(3)}, + }, + ) ) ) ) @@ -67,19 +73,24 @@ def test_nd_variable_array(): assert isinstance(variable_group[0], nodes.Variable) variable_group = groups.NDVariableArray(3, (2, 2)) with pytest.raises( - ValueError, match="data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)." + ValueError, + match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."), ): variable_group.flatten(np.zeros((3, 3))) assert jnp.all( variable_group.flatten(np.array([[1, 2], [3, 4]])) == jnp.array([1, 2, 3, 4]) ) - with pytest.rasies(ValueError, "Can only unflatten 1D array. Got a 2D array."): + with pytest.raises( + ValueError, match="Can only unflatten 1D array. Got a 2D array." + ): variable_group.unflatten(np.zeros((10, 20))) with pytest.raises( ValueError, - "flat_data should be compatible with shape (2, 2) or (2, 2, 3). Got (10,).", + match=re.escape( + "flat_data should be compatible with shape (2, 2) or (2, 2, 3). Got (10,)." + ), ): variable_group.unflatten(np.zeros((10,))) @@ -90,24 +101,27 @@ def test_nd_variable_array(): def test_enumeration_factor_group(): variable_group = groups.NDVariableArray(3, (2, 2)) with pytest.raises( - "ValueError", match="Expected log potentials shape: (1,) or (2, 1). Got (3, 2)" + ValueError, + match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"), ): enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], - factor_configs=np.zeros((1, 3)), + factor_configs=np.zeros((1, 3), dtype=int), log_potentials=np.zeros((3, 2)), ) enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], - factor_configs=np.zeros((1, 3)), + factor_configs=np.zeros((1, 3), dtype=int), ) key = [(0, 0), (1, 1)] with pytest.raises( ValueError, - match=f"The queried factor {frozenset(key)} is not present in the factor group.", + match=re.escape( + f"The queried factor {frozenset(key)} is not present in the factor group." + ), ): enumeration_factor_group[key] @@ -116,20 +130,25 @@ def test_enumeration_factor_group(): == enumeration_factor_group.factors[1] ) with pytest.raises( - ValueError, "data should be of shape (2, 1) or (2, 9) or (1,). Got (4, 5)." + ValueError, + match=re.escape( + "data should be of shape (2, 1) or (2, 9) or (1,). Got (4, 5)." + ), ): - enumeration_factor_group.factorslatten(np.zeros((4, 5))) + enumeration_factor_group.flatten(np.zeros((4, 5))) assert jnp.all(enumeration_factor_group.flatten(np.ones(1)) == jnp.ones(2)) - assert jnp.all(enumeration_factor_group.flatten(np.ones(2, 9)) == jnp.ones(18)) + assert jnp.all(enumeration_factor_group.flatten(np.ones((2, 9))) == jnp.ones(18)) with pytest.raises( - ValueError, match="Can only unflatten 1D array. Got a 3D array." + ValueError, match=re.escape("Can only unflatten 1D array. Got a 3D array.") ): enumeration_factor_group.unflatten(jnp.ones((1, 2, 3))) with pytest.raises( ValueError, - match="flat_data should be compatible with shape (2, 1) or (2, 9). Got (30,)", + match=re.escape( + "flat_data should be compatible with shape (2, 1) or (2, 9). Got (30,)" + ), ): enumeration_factor_group.unflatten(jnp.zeros(30)) @@ -142,7 +161,7 @@ def test_enumeration_factor_group(): def test_pairwise_factor_group(): variable_group = groups.NDVariableArray(3, (2, 2)) with pytest.raises( - ValueError, match="log_potential_matrix should be either a 2D array" + ValueError, match=re.escape("log_potential_matrix should be either a 2D array") ): groups.PairwiseFactorGroup( variable_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float) @@ -150,7 +169,9 @@ def test_pairwise_factor_group(): with pytest.raises( ValueError, - match="Expected log_potential_matrix for 1 factors. Got log_potential_matrix for 2 factors.", + match=re.escape( + "Expected log_potential_matrix for 1 factors. Got log_potential_matrix for 2 factors." + ), ): groups.PairwiseFactorGroup( variable_group, [[(0, 0), (1, 1)]], np.zeros((2, 3, 3), dtype=float) @@ -158,7 +179,9 @@ def test_pairwise_factor_group(): with pytest.raises( ValueError, - match="All pairwise factors should connect to exactly 2 variables. Got a factor connecting to 3 variables.", + match=re.escape( + "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to 3 variables" + ), ): groups.PairwiseFactorGroup( variable_group, [[(0, 0), (1, 1), (0, 1)]], np.zeros((3, 3), dtype=float) @@ -166,7 +189,7 @@ def test_pairwise_factor_group(): with pytest.raises( ValueError, - match="The specified pairwise factor [(0, 0), (1, 1)].", + match=re.escape("The specified pairwise factor [(0, 0), (1, 1)]"), ): groups.PairwiseFactorGroup( variable_group, [[(0, 0), (1, 1)]], np.zeros((4, 4), dtype=float) @@ -179,7 +202,9 @@ def test_pairwise_factor_group(): ) with pytest.raises( ValueError, - match="data should be of shape (2, 3, 3) or (2, 6) or (3, 3). Got (4, 4).", + match=re.escape( + "data should be of shape (2, 3, 3) or (2, 6) or (3, 3). Got (4, 4)." + ), ): pairwise_factor_group.flatten(np.zeros((4, 4))) From 86f467ea6dea7fb25a10731fe4d25fab4024d2be Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 23:02:09 -0700 Subject: [PATCH 37/56] New test nodes --- tests/fg/test_nodes.py | 107 ++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 55 deletions(-) diff --git a/tests/fg/test_nodes.py b/tests/fg/test_nodes.py index fb41be5e..1a8b2cf7 100644 --- a/tests/fg/test_nodes.py +++ b/tests/fg/test_nodes.py @@ -1,61 +1,58 @@ +import re + import numpy as np import pytest from pgmax.fg import nodes -def test_enumfactor_configints_error(): - v = nodes.Variable(3) - configs = np.array([[1.0]]) - log_potentials = np.array([1.0]) - - with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, log_potentials) - - assert "Configurations" in str(verror.value) - - -def test_enumfactor_potentials_error(): - v = nodes.Variable(3) - configs = np.array([[1]], dtype=int) - log_potentials = np.array([1], dtype=int) - - with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, log_potentials) - - assert "Potential" in str(verror.value) - - -def test_enumfactor_configsshape_error(): - v1 = nodes.Variable(3) - v2 = nodes.Variable(3) - configs = np.array([[1]], dtype=int) - log_potentials = np.array([1.0]) - - with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v1, v2]), configs, log_potentials) - - assert "Number of variables" in str(verror.value) - - -def test_enumfactor_potentialshape_error(): - v = nodes.Variable(3) - configs = np.array([[1]], dtype=int) - log_potentials = np.array([1.0, 2.0]) - - with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v]), configs, log_potentials) - - assert "The potential array has" in str(verror.value) - - -def test_enumfactor_configvarsize_error(): - v1 = nodes.Variable(3) - v2 = nodes.Variable(1) - configs = np.array([[-1, 4]], dtype=int) - log_potentials = np.array([1.0]) - - with pytest.raises(ValueError) as verror: - nodes.EnumerationFactor(tuple([v1, v2]), configs, log_potentials) - - assert "Invalid configurations for given variables" in str(verror.value) +def test_enumeration_factor(): + variable = nodes.Variable(3) + with pytest.raises(ValueError, match="Configurations should be integers. Got"): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([[1.0]]), + log_potentials=np.array([0.0]), + ) + + with pytest.raises(ValueError, match="Potential should be floats. Got"): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([[1]]), + log_potentials=np.array([0]), + ) + + with pytest.raises(ValueError, match="configs should be a 2D array"): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([1]), + log_potentials=np.array([0.0]), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Number of variables 1 doesn't match given configurations (1, 2)" + ), + ): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([[1, 2]]), + log_potentials=np.array([0.0]), + ) + + with pytest.raises( + ValueError, match=re.escape("Expected log potentials of shape (1,)") + ): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([[1]]), + log_potentials=np.array([0.0, 1.0]), + ) + + with pytest.raises(ValueError, match="Invalid configurations for given variables"): + nodes.EnumerationFactor( + variables=(variable,), + configs=np.array([[10]]), + log_potentials=np.array([0.0]), + ) From eda93fc5016eb9230bb7d39a5fe82d93594d80c8 Mon Sep 17 00:00:00 2001 From: stannis Date: Mon, 25 Oct 2021 23:58:23 -0700 Subject: [PATCH 38/56] Pass test graph --- pgmax/fg/graph.py | 28 ++++---- tests/fg/test_graph.py | 154 ++++++++++++++++++++++++++++++----------- 2 files changed, 130 insertions(+), 52 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6abe176c..9de393f5 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -179,20 +179,12 @@ def add_factor( self._factor_group_to_msgs_starts[factor_group] + factor_num_states_cumsum[vv] ) - self._factor_group_to_potentials_starts[factor] = ( + self._factor_to_potentials_starts[factor] = ( self._factor_group_to_potentials_starts[factor_group] + vv * factor.log_potentials.shape[0] ) factor_group_num_configs += factor.log_potentials.shape[0] - if ( - factor_group_num_configs - != factor_group.factor_group_log_potentials.shape[0] - ): - raise ValueError( - "Factors in a factor group should have the same number of valid configurations." - ) - self._total_factor_num_states += factor_num_states_cumsum[-1] self._total_factor_num_configs += factor_group_num_configs if name is not None: @@ -378,6 +370,9 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, key: Any): + if not isinstance(key, Hashable): + key = frozenset(key) + if key in self.fg_state.named_factor_groups: factor_group = self.fg_state.named_factor_groups[key] start = self.fg_state.factor_group_to_potentials_starts[factor_group] @@ -391,7 +386,7 @@ def __getitem__(self, key: Any): start : start + factor.log_potentials.shape[0] ] else: - raise ValueError("") + raise ValueError(f"Invalid key {key} for log potentials updates.") return log_potentials @@ -400,6 +395,9 @@ def __setitem__( key: Any, data: Union[np.ndarray, jnp.ndarray], ): + if not isinstance(key, Hashable): + key = frozenset(key) + object.__setattr__( self, "value", @@ -428,7 +426,7 @@ def update_ftov_msgs( if data.shape != (variable.num_states,): raise ValueError( f"Given message shape {data.shape} does not match expected " - f"shape f{(variable.num_states,)} from factor {keys[0]} " + f"shape {(variable.num_states,)} from factor {keys[0]} " f"to variable {keys[1]}." ) @@ -438,7 +436,7 @@ def update_ftov_msgs( if data.shape != (variable.num_states,): raise ValueError( f"Given belief shape {data.shape} does not match expected " - f"shape f{(variable.num_states,)} for variable {keys}." + f"shape {(variable.num_states,)} for variable {keys}." ) starts = np.nonzero( @@ -617,6 +615,12 @@ def __post_init__(self): if self.value is None: object.__setattr__(self, "value", jnp.zeros(self.fg_state.num_var_states)) else: + if self.value.shape != (self.fg_state.num_var_states,): + raise ValueError( + f"Expected evidence shape {(self.fg_state.num_var_states,)}. " + f"Got {self.value.shape}." + ) + object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, key: Any) -> jnp.ndarray: diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index eb30b801..8f02fc8b 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -1,57 +1,131 @@ +import re + +import jax.numpy as jnp import numpy as np import pytest from pgmax.fg import graph, groups -def test_onevar_graph(): - v_group = groups.VariableDict(15, (0,)) - fg = graph.FactorGraph(v_group) - evidence = graph.Evidence(factor_graph=fg, value=np.zeros(1)) - assert np.all(evidence.value == 0) - assert fg._variable_group[0].num_states == 15 - with pytest.raises(ValueError) as verror: - graph.FToVMessages( - factor_graph=fg, default_mode="zeros", init_value=np.zeros(1) +def test_factor_graph(): + variable_group = groups.VariableDict(15, (0,)) + fg = graph.FactorGraph(variable_group) + fg.add_factor([0], np.arange(15)[:, None], name="test") + with pytest.raises( + ValueError, + match="A factor group with the name test already exists. Please choose a different name", + ): + fg.add_factor([0], np.arange(15)[:, None], name="test") + + with pytest.raises( + ValueError, + match=re.escape( + f"A factor involving variables {frozenset([0])} already exists." + ), + ): + fg.add_factor([0], np.arange(10)[:, None]) + + +def test_bp_state(): + variable_group = groups.VariableDict(15, (0,)) + fg0 = graph.FactorGraph(variable_group) + fg0.add_factor([0], np.arange(10)[:, None], name="test") + fg1 = graph.FactorGraph(variable_group) + fg1.add_factor([0], np.arange(15)[:, None], name="test") + with pytest.raises( + ValueError, + match="log_potentials, ftov_msgs and evidence should be derived from the same fg_state", + ): + graph.BPState( + log_potentials=fg0.bp_state.log_potentials, + ftov_msgs=fg1.bp_state.ftov_msgs, + evidence=fg1.bp_state.evidence, ) - assert "Should specify only" in str(verror.value) - with pytest.raises(ValueError) as verror: - graph.FToVMessages(factor_graph=fg, default_mode="test") - assert "Unsupported default message mode" in str(verror.value) - with pytest.raises(ValueError) as verror: - graph.Evidence(factor_graph=fg, default_mode="zeros", value=np.zeros(1)) +def test_log_potentials(): + variable_group = groups.VariableDict(15, (0,)) + fg = graph.FactorGraph(variable_group) + fg.add_factor([0], np.arange(10)[:, None], name="test") + with pytest.raises( + ValueError, + match=re.escape("Expected log potentials shape (10,) for factor group test."), + ): + fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) - assert "Should specify only" in str(verror.value) - with pytest.raises(ValueError) as verror: - graph.Evidence(factor_graph=fg, default_mode="test") + fg.bp_state.log_potentials[[0]] = np.zeros(10) + with pytest.raises( + ValueError, + match=re.escape( + f"Expected log potentials shape (10,) for factor {frozenset([0])}. Got (15,)" + ), + ): + fg.bp_state.log_potentials[[0]] = np.zeros(15) - assert "Unsupported default evidence mode" in str(verror.value) - fg.add_factor([0], np.arange(15)[:, None], name="test") - with pytest.raises(ValueError) as verror: - fg.add_factor([0], np.arange(15)[:, None], name="test") + with pytest.raises( + ValueError, + match=re.escape(f"Invalid key {frozenset([1])} for log potentials updates."), + ): + fg.bp_state.log_potentials[frozenset([1])] = np.zeros(10) + + with pytest.raises( + ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") + ): + graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) + + assert jnp.all(fg.bp_state.log_potentials["test"] == jnp.zeros(10)) + assert jnp.all(fg.bp_state.log_potentials[[0]] == jnp.zeros(10)) + with pytest.raises( + ValueError, + match=re.escape(f"Invalid key {frozenset([1])} for log potentials updates."), + ): + fg.bp_state.log_potentials[[1]] + + +def test_ftov_msgs(): + variable_group = groups.VariableDict(15, (0,)) + fg = graph.FactorGraph(variable_group) + fg.add_factor([0], np.arange(10)[:, None], name="test") + with pytest.raises( + ValueError, + match=re.escape( + f"Given message shape (10,) does not match expected shape (15,) from factor {frozenset([0])} to variable 0" + ), + ): + fg.bp_state.ftov_msgs[[0], 0] = np.ones(10) + + with pytest.raises( + ValueError, + match=re.escape( + "Given belief shape (10,) does not match expected shape (15,) for variable 0" + ), + ): + fg.bp_state.ftov_msgs[0] = np.ones(10) - assert "A factor group with the name" in str(verror.value) - init_msgs = fg.get_init_msgs() - init_msgs.evidence[:] = {0: np.ones(15)} - with pytest.raises(ValueError) as verror: - init_msgs.ftov["test", 1] + with pytest.raises( + ValueError, + match=re.escape("Invalid keys for setting messages"), + ): + fg.bp_state.ftov_msgs[1] = np.ones(10) - assert "Invalid keys" in str(verror.value) - with pytest.raises(ValueError) as verror: - init_msgs.ftov["test", 0] = np.zeros(1) + with pytest.raises( + ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") + ): + graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) - assert "Given message shape" in str(verror.value) - with pytest.raises(ValueError) as verror: - init_msgs.ftov[0] = np.zeros(1) + ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) + with pytest.raises(ValueError, match=re.escape("Invalid keys (10,)")): + ftov_msgs[(10,)] - assert "Given belief shape" in str(verror.value) - with pytest.raises(ValueError) as verror: - init_msgs.ftov[1] = np.zeros(1) - assert "Invalid keys for setting messages" in str(verror.value) - with pytest.raises(ValueError) as verror: - graph.FToVMessages(factor_graph=fg, init_value=np.zeros(1)).value +def test_evidence(): + variable_group = groups.VariableDict(15, (0,)) + fg = graph.FactorGraph(variable_group) + fg.add_factor([0], np.arange(10)[:, None], name="test") + with pytest.raises( + ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") + ): + graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10)) - assert "Expected messages shape" in str(verror.value) + evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) + assert jnp.all(evidence.value == jnp.zeros(15)) From 5d63a12cb44eebaf578317befcefffb3db816298 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 12:29:33 -0700 Subject: [PATCH 39/56] Full coverage of graph --- tests/fg/test_graph.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 8f02fc8b..b6ab4c1c 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -1,4 +1,5 @@ import re +from dataclasses import replace import jax.numpy as jnp import numpy as np @@ -73,8 +74,9 @@ def test_log_potentials(): ): graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) - assert jnp.all(fg.bp_state.log_potentials["test"] == jnp.zeros(10)) - assert jnp.all(fg.bp_state.log_potentials[[0]] == jnp.zeros(10)) + log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) + assert jnp.all(log_potentials["test"] == jnp.zeros(10)) + assert jnp.all(log_potentials[[0]] == jnp.zeros(10)) with pytest.raises( ValueError, match=re.escape(f"Invalid key {frozenset([1])} for log potentials updates."), @@ -129,3 +131,16 @@ def test_evidence(): evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) assert jnp.all(evidence.value == jnp.zeros(15)) + + +def test_bp(): + variable_group = groups.VariableDict(15, (0,)) + fg = graph.FactorGraph(variable_group) + fg.add_factor([0], np.arange(10)[:, None], name="test") + run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, 1) + bp_arrays = replace( + run_bp(ftov_msgs_updates={(frozenset([0]), 0): np.zeros(15)}), + log_potentials=np.zeros(10), + ) + bp_state = get_bp_state(bp_arrays) + assert bp_state.fg_state == fg.fg_state From 5b7b44ead3383f269866039f1b3e45a21182f21f Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 12:38:34 -0700 Subject: [PATCH 40/56] Support default log_potential_matrix for pairwise factor groups --- pgmax/fg/groups.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 468b1fa9..3f8f8e0e 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -684,7 +684,7 @@ class PairwiseFactorGroup(FactorGroup): """ connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] - log_potential_matrix: np.ndarray + log_potential_matrix: Optional[np.ndarray] = None def _get_variables_to_factors( self, @@ -699,21 +699,29 @@ def _get_variables_to_factors( log_potential_matrix is not the same as the variable sizes for each variable referenced in each sub-list of self.connected_var_keys """ - if not ( - self.log_potential_matrix.ndim == 2 or self.log_potential_matrix.ndim == 3 - ): + if self.log_potential_matrix is None: + log_potential_matrix = np.zeros( + ( + self.variable_group[self.connected_var_keys[0][0]].num_states, + self.variable_group[self.connected_var_keys[0][1]].num_states, + ) + ) + else: + log_potential_matrix = self.log_potential_matrix + + if not (log_potential_matrix.ndim == 2 or log_potential_matrix.ndim == 3): raise ValueError( "log_potential_matrix should be either a 2D array, specifying shared parameters for all " "pairwise factors, or 3D array, specifying parameters for individual pairwise factors. " - f"Got a {self.log_potential_matrix.ndim}D log_potential_matrix array." + f"Got a {log_potential_matrix.ndim}D log_potential_matrix array." ) - if self.log_potential_matrix.ndim == 3 and self.log_potential_matrix.shape[ - 0 - ] != len(self.connected_var_keys): + if log_potential_matrix.ndim == 3 and log_potential_matrix.shape[0] != len( + self.connected_var_keys + ): raise ValueError( f"Expected log_potential_matrix for {len(self.connected_var_keys)} factors. " - f"Got log_potential_matrix for {self.log_potential_matrix.shape[0]} factors." + f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) for fac_list in self.connected_var_keys: @@ -724,7 +732,7 @@ def _get_variables_to_factors( ) if not ( - self.log_potential_matrix.shape[-2:] + log_potential_matrix.shape[-2:] == ( self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states, @@ -734,20 +742,21 @@ def _get_variables_to_factors( f"The specified pairwise factor {fac_list} (with " f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " f"configurations) does not match the specified log_potential_matrix " - f"(with {self.log_potential_matrix.shape[-2:]} configurations)." + f"(with {log_potential_matrix.shape[-2:]} configurations)." ) factor_configs = ( np.mgrid[ - : self.log_potential_matrix.shape[0], - : self.log_potential_matrix.shape[1], + : log_potential_matrix.shape[0], + : log_potential_matrix.shape[1], ] .transpose((1, 2, 0)) .reshape((-1, 2)) ) + object.__setattr__(self, "log_potential_matrix", log_potential_matrix) log_potential_matrix = np.broadcast_to( - self.log_potential_matrix, - (len(self.connected_var_keys),) + self.log_potential_matrix.shape[-2:], + log_potential_matrix, + (len(self.connected_var_keys),) + log_potential_matrix.shape[-2:], ) variables_to_factors = collections.OrderedDict( [ @@ -767,6 +776,7 @@ def _get_variables_to_factors( return variables_to_factors def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + assert isinstance(self.log_potential_matrix, np.ndarray) num_factors = len(self.factors) if ( data.shape != (num_factors,) + self.log_potential_matrix.shape[-2:] @@ -795,6 +805,7 @@ def unflatten( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." ) + assert isinstance(self.log_potential_matrix, np.ndarray) num_factors = len(self.factors) if flat_data.size == num_factors * np.product( self.log_potential_matrix.shape[-2:] From 4e3856ab5f21e9435345bb22cd3a994a4c735c65 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 13:07:16 -0700 Subject: [PATCH 41/56] Full coverage --- pgmax/fg/groups.py | 33 ++++--------------- tests/fg/test_groups.py | 73 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 3f8f8e0e..8b372eef 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -23,7 +23,6 @@ import numpy as np import pgmax.fg.nodes as nodes -from pgmax.fg import fg_utils from pgmax.utils import cached_property @@ -255,8 +254,8 @@ def unflatten( use_num_states = True else: raise ValueError( - f"flat_data should either be of shape (num_variables={len(self.variables)},), " - f"or (num_variable_states={num_variable_states},). " + f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + f"or (num_variable_states(={num_variable_states}),). " f"Got {flat_data.shape}" ) @@ -392,9 +391,8 @@ def flatten( if data[key].shape != (self.variable_size,): raise ValueError( - f"Variable {key} expects an data array of shape " - f"({(self.variable_size,)})." - f"Got {data[key].shape}." + f"Variable {key} expects a data array of shape " + f"{(self.variable_size,)}. Got {data[key].shape}." ) flat_data = jnp.concatenate([data[key].flatten() for key in self.keys]) @@ -416,8 +414,8 @@ def unflatten( use_num_states = True else: raise ValueError( - f"flat_data should either be of shape (num_variables={len(self.variables)},), " - f"or (num_variable_states={num_variable_states},). " + f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + f"or (num_variable_states(={num_variable_states}),). " f"Got {flat_data.shape}" ) @@ -484,23 +482,6 @@ def __getitem__( return self._variables_to_factors[variables] - def compile_wiring( - self, vars_to_starts: Mapping[nodes.Variable, int] - ) -> nodes.EnumerationWiring: - """Function to compile wiring for the factor group. - - Args: - vars_to_starts: A dictionary that maps variables to their global starting indices - For an n-state variable, a global start index of m means the global indices - of its n variable states are m, m + 1, ..., m + n - 1 - - Returns: - compiled wiring for the factor group - """ - wirings = [factor.compile_wiring(vars_to_starts) for factor in self.factors] - wiring = fg_utils.concatenate_enumeration_wirings(wirings) - return wiring - @cached_property def factor_group_log_potentials(self) -> np.ndarray: """Function to compile potential array for the factor group @@ -822,7 +803,7 @@ def unflatten( else: raise ValueError( f"flat_data should be compatible with shape {(num_factors,) + self.log_potential_matrix.shape[-2:]} " - f"or (num_factors, np.sum(self.log_potential_matrix.shape[-2:])). Got {flat_data.shape}." + f"or {(num_factors, np.sum(self.log_potential_matrix.shape[-2:]))}. Got {flat_data.shape}." ) return data diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index bf1d675a..fbf594f8 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -57,15 +57,65 @@ def test_composite_variable_group(): jax.tree_util.tree_leaves( jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), - composite_variable_dict.unflatten(jnp.zeros(15 * 3 * 2)), + composite_variable_dict.unflatten(jnp.zeros(3 * 2)), { - (0, 1): {key: np.zeros(15) for key in range(3)}, - (2, 3): {key: np.zeros(15) for key in range(3)}, + (0, 1): {key: np.zeros(1) for key in range(3)}, + (2, 3): {key: np.zeros(1) for key in range(3)}, }, ) ) ) ) + with pytest.raises( + ValueError, match="Can only unflatten 1D array. Got a 2D array." + ): + composite_variable_dict.unflatten(jnp.zeros((10, 20))) + + with pytest.raises( + ValueError, + match=re.escape( + "flat_data should be either of shape (num_variables(=6),), or (num_variable_states(=90),)" + ), + ): + composite_variable_dict.unflatten(jnp.zeros((100))) + + +def test_variable_dict(): + variable_dict = groups.VariableDict(15, tuple([0, 1, 2])) + with pytest.raises( + ValueError, match="data is referring to a non-existent variable 3" + ): + variable_dict.flatten({3: np.zeros(10)}) + + with pytest.raises( + ValueError, + match=re.escape("Variable 2 expects a data array of shape (15,). Got (10,)"), + ): + variable_dict.flatten({2: np.zeros(10)}) + + with pytest.raises( + ValueError, match="Can only unflatten 1D array. Got a 2D array." + ): + variable_dict.unflatten(jnp.zeros((10, 20))) + + assert jnp.all( + jnp.array( + jax.tree_util.tree_leaves( + jax.tree_util.tree_multimap( + lambda x, y: jnp.all(x == y), + variable_dict.unflatten(jnp.zeros(3)), + {key: np.zeros(1) for key in range(3)}, + ) + ) + ) + ) + with pytest.raises( + ValueError, + match=re.escape( + "flat_data should be either of shape (num_variables(=3),), or (num_variable_states(=45),)" + ), + ): + variable_dict.unflatten(jnp.zeros((100))) def test_nd_variable_array(): @@ -198,7 +248,6 @@ def test_pairwise_factor_group(): pairwise_factor_group = groups.PairwiseFactorGroup( variable_group, [[(0, 0), (1, 1)], [(1, 0), (0, 1)]], - np.zeros((3, 3), dtype=float), ) with pytest.raises( ValueError, @@ -212,3 +261,19 @@ def test_pairwise_factor_group(): pairwise_factor_group.flatten(np.zeros((3, 3))) == jnp.zeros(2 * 3 * 3) ) assert jnp.all(pairwise_factor_group.flatten(np.zeros((2, 6))) == jnp.zeros(12)) + with pytest.raises(ValueError, match="Can only unflatten 1D array. Got a 2D array"): + pairwise_factor_group.unflatten(np.zeros((10, 20))) + + assert jnp.all( + pairwise_factor_group.unflatten(np.zeros(2 * 3 * 3)) == jnp.zeros((2, 3, 3)) + ) + assert jnp.all( + pairwise_factor_group.unflatten(np.zeros(2 * 6)) == jnp.zeros((2, 6)) + ) + with pytest.raises( + ValueError, + match=re.escape( + "flat_data should be compatible with shape (2, 3, 3) or (2, 6). Got (10,)." + ), + ): + pairwise_factor_group.unflatten(np.zeros(10)) From 1a356225956966929815665c061f93f9f109379b Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 15:05:54 -0700 Subject: [PATCH 42/56] Docstrings --- pgmax/fg/groups.py | 137 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 3 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 8b372eef..7ed42bec 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -122,11 +122,27 @@ def container_keys(self) -> Tuple: return (None,) def flatten(self, data: Any) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data + + Returns: + A flat jnp.array for internal use + """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data + """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) @@ -223,6 +239,16 @@ def _get_keys_to_vars(self) -> OrderedDict[Hashable, nodes.Variable]: return keys_to_vars def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. + The structure of data should match self.variable_group_container. + + + Returns: + A flat jnp.array for internal use + """ flat_data = jnp.concatenate( [ self.variable_group_container[key].flatten(data[key]) @@ -234,6 +260,14 @@ def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Union[Mapping, Sequence]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data, with structure matching that of self.variable_group_container. + """ if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." @@ -324,6 +358,15 @@ def _get_keys_to_vars( return keys_to_vars def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings) + or self.shape + (self.variable_size,) (for e.g. evidence, beliefs). + + Returns: + A flat jnp.array for internal use + """ if data.shape != self.shape and data.shape != self.shape + ( self.variable_size, ): @@ -337,6 +380,15 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Union[np.ndarray, jnp.ndarray]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. An array of shape self.shape (for e.g. MAP decodings) + or an array of shape self.shape + (self.variable_size,) (for e.g. evidence, beliefs). + """ if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." @@ -385,14 +437,24 @@ def _get_keys_to_vars(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: def flatten( self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] ) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. Should be a mapping with keys from self.variable_names. + Each value should be an array of shape (1,) (for e.g. MAP decodings) or + (self.variable_size,) (for e.g. evidence, beliefs). + + Returns: + A flat jnp.array for internal use + """ for key in data: if key not in self._keys_to_vars: raise ValueError(f"data is referring to a non-existent variable {key}.") - if data[key].shape != (self.variable_size,): + if data[key].shape != (self.variable_size,) and data[key].shape != (1,): raise ValueError( f"Variable {key} expects a data array of shape " - f"{(self.variable_size,)}. Got {data[key].shape}." + f"{(self.variable_size,)} or (1,). Got {data[key].shape}." ) flat_data = jnp.concatenate([data[key].flatten() for key in self.keys]) @@ -401,6 +463,17 @@ def flatten( def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. Should be a mapping with keys from self.variable_names. + Each value should be an array of shape (1,) (for e.g. MAP decodings) or + (self.variable_size,) (for e.g. evidence, beliefs). + + """ if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." @@ -426,7 +499,7 @@ def unflatten( data[key] = flat_data[start : start + self.variable_size] start += self.variable_size else: - data[key] = flat_data[start] + data[key] = flat_data[[start]] start += 1 return data @@ -518,11 +591,27 @@ def factor_num_states(self) -> np.ndarray: return factor_num_states def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. + + Returns: + A flat jnp.array for internal use + """ raise NotImplementedError( "Please subclass the FactorGroup class and override this method" ) def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. + """ raise NotImplementedError( "Please subclass the FactorGroup class and override this method" ) @@ -596,6 +685,16 @@ def _get_variables_to_factors( return variables_to_factors def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. Should be an array of shape (num_val_configs,) + (for shared log potentials) or (num_factors, num_val_configs) (for log potentials) + or (num_factors, num_edge_states) (for ftov messages). + + Returns: + A flat jnp.array for internal use + """ num_factors = len(self.factors) if ( data.shape != (num_factors, self.factor_configs.shape[0]) @@ -622,6 +721,16 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Union[np.ndarray, jnp.ndarray]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. Should be an array of shape (num_val_configs,) + (for shared log potentials) or (num_factors, num_val_configs) (for log potentials) + or (num_factors, num_edge_states) (for ftov messages). + """ if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." @@ -757,6 +866,17 @@ def _get_variables_to_factors( return variables_to_factors def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. Should be an array of shape + (num_factors, var0_num_states, var1_num_states) (for log potential matrices) + or (num_factors, var0_num_states + var1_num_states) (for ftov messages) + or (var0_num_states, var1_num_states) (for shared log potential matrix). + + Returns: + A flat jnp.array for internal use + """ assert isinstance(self.log_potential_matrix, np.ndarray) num_factors = len(self.factors) if ( @@ -781,6 +901,17 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: def unflatten( self, flat_data: Union[np.ndarray, jnp.ndarray] ) -> Union[np.ndarray, jnp.ndarray]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. Should be an array of shape + (num_factors, var0_num_states, var1_num_states) (for log potential matrices) + or (num_factors, var0_num_states + var1_num_states) (for ftov messages) + or (var0_num_states, var1_num_states) (for shared log potential matrix). + """ if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." From dce2b2f8c37f1f3213eb1db61b980b0ba1f1a2ba Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 15:32:37 -0700 Subject: [PATCH 43/56] Docstrings --- pgmax/fg/graph.py | 132 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 92 insertions(+), 40 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 9de393f5..0b68283c 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -39,23 +39,6 @@ class FactorGraph: For a sequence, the indices of the sequence are used to index the variable groups. Note that if not a single VariableGroup, a CompositeVariableGroup will be created from this input, and the individual VariableGroups will need to be accessed by indexing. - - Attributes: - _variable_group: VariableGroup. contains all involved VariableGroups - num_var_states: int. represents the sum of all variable states of all variables in the - FactorGraph - _vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int - representing an index in the evidence array at which the first entry of the evidence - for that particular variable should be placed. - _vargroups_set: Set[groups.VariableGroup]. keeps track of all the VariableGroup's that have - been added to this FactorGraph - _named_factor_groups: Dict[Hashable, groups.FactorGroup]. A dictionary mapping the names of - named factor groups to the corresponding factor groups. - We only support setting messages from factors within explicitly named factor groups - to connected variables. - _total_factor_num_states: int. Current total number of edge states for the added factors. - _factor_group_to_msgs_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its - corresponding starting index in the flat message array. """ variables: Union[ @@ -78,7 +61,7 @@ def __post_init__(self): 0, 0, ) - self.num_var_states = vars_num_states_cumsum[-1] + self._num_var_states = vars_num_states_cumsum[-1] self._vars_to_starts = MappingProxyType( { variable: vars_num_states_cumsum[vv] @@ -197,7 +180,7 @@ def wiring(self) -> nodes.EnumerationWiring: If wiring has already beeen compiled, do nothing. Returns: - compiled wiring from each individual factor + Compiled wiring from individual factors. """ wirings = [ factor.compile_wiring(self._vars_to_starts) for factor in self.factors @@ -212,7 +195,7 @@ def log_potentials(self) -> np.ndarray: If potential array has already beeen compiled, do nothing. Returns: - a jnp array representing the log of the potential function for each + A jnp array representing the log of the potential function for each valid configuration """ return np.concatenate( @@ -234,10 +217,11 @@ def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: @cached_property def fg_state(self) -> FactorGraphState: + """Current factor graph state given the added factors.""" return FactorGraphState( variable_group=self._variable_group, vars_to_starts=self._vars_to_starts, - num_var_states=self.num_var_states, + num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, variables_to_factors=copy.copy(self._variables_to_factors), named_factor_groups=copy.copy(self._named_factor_groups), @@ -252,6 +236,7 @@ def fg_state(self) -> FactorGraphState: @property def bp_state(self) -> BPState: + """Relevant information for doing belief propagation.""" return BPState( log_potentials=LogPotentials(fg_state=self.fg_state), ftov_msgs=FToVMessages(fg_state=self.fg_state), @@ -261,6 +246,32 @@ def bp_state(self) -> BPState: @dataclass(frozen=True, eq=False) class FactorGraphState: + """FactorGraphState. + + Args: + variable_group: VariableGroup. contains all involved VariableGroups + vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int + representing an index in the evidence array at which the first entry of the evidence + for that particular variable should be placed. + num_var_states: int. represents the sum of all variable states of all variables in the + FactorGraph + total_factor_num_states: + variables_to_factors: + named_factor_groups: Dict[Hashable, groups.FactorGroup]. A dictionary mapping the names of + named factor groups to the corresponding factor groups. + We only support setting messages from factors within explicitly named factor groups + to connected variables. + factor_group_to_potentials_starts: + factor_to_potentials_starts: + factor_group_to_msgs_starts: + factor_to_msgs_starts: + total_factor_num_states: int. Current total number of edge states for the added factors. + factor_group_to_msgs_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its + corresponding starting index in the flat message array. + log_potentials: + wiring: + """ + variable_group: groups.VariableGroup vars_to_starts: Mapping[nodes.Variable, int] num_var_states: int @@ -290,7 +301,7 @@ class BPState: Args: log_potentials: log potentials of the model ftov_msgs: factor to variable messages - evidence: evidence + evidence: evidence (unary log potentials) for variables. """ log_potentials: LogPotentials @@ -351,6 +362,12 @@ def update_log_potentials( @dataclass(frozen=True, eq=False) class LogPotentials: + """Class for storing and manipulating log potentials. + + Args: + fg_state: Factor graph state + value: Optionally specify an initial value + """ fg_state: FactorGraphState value: Optional[np.ndarray] = None @@ -464,13 +481,9 @@ class FToVMessages: """Class for storing and manipulating factor to variable messages. Args: - factor_graph: associated factor graph + fg_state: Factor graph state value: Optionally specify initial value for ftov messages - Attributes: - _message_updates: Dict[int, jnp.ndarray]. A dictionary containing - the message updates to make on top of initial message values. - Maps starting indices to the message values to update with. """ fg_state: FactorGraphState @@ -600,12 +613,8 @@ class Evidence: """Class for storing and manipulating evidence Args: - factor_graph: associated factor graph + fg_state: Factor graph state value: Optionally specify initial value for evidence - - Attributes: - _evidence_updates: Dict[nodes.Variable, np.ndarray]. maps every variable to an np.ndarray - representing the evidence for that variable """ fg_state: FactorGraphState @@ -676,6 +685,13 @@ def __setitem__( @jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class BPArrays: + """Container for the relevant flat arrays used in belief propagation. + + Args: + log_potentials: Flat log potentials array. + ftov_msgs: Flat factor to variable messages array. + evidence: Flat evidence array. + """ log_potentials: Union[np.ndarray, jnp.ndarray] ftov_msgs: Union[np.ndarray, jnp.ndarray] @@ -695,6 +711,19 @@ def tree_unflatten(cls, aux_data, children): def BP(bp_state: BPState, num_iters: int): + """Function for generating belief propagation functions. + + Args: + bp_state: Belief propagation state. + num_iters: Number of belief propagation iterations. + + Returns: + run_bp: Function for running belief propagation for num_iters. + Optionally takes as input log_potentials updates, ftov_msgs updates, + evidence updates, and damping factor, and outputs a BPArrays. + get_bp_state: Function to reconstruct the BPState from BPArrays. + get_beliefs: Function to calculate beliefs from BPArrays. + """ wiring = jax.device_put(bp_state.fg_state.wiring) max_msg_size = int(jnp.max(wiring.edges_num_states)) num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 @@ -709,18 +738,17 @@ def run_bp( """Function to perform belief propagation. Specifically, belief propagation is run for num_iters iterations and - returns the resulting messages. + returns a BPArrays containing the updated log_potentials, ftov_msgs and evidence. Args: - num_iters: The number of iterations for which to perform message passing + log_potentials_updates: Dictionary containing optional log_potentials updates. + ftov_msgs_updates: Dictionary containing optional ftov_msgs updates. + evidence_updates: Dictionary containing optional evidence updates. damping: The damping factor to use for message updates between one timestep and the next - bp_state: Initial messages to start the belief propagation. Returns: - ftov messages after running BP for num_iters iterations + A BPArrays containing the updated log_potentials, ftov_msgs and evidence. """ - # Retrieve the necessary data structures from the compiled self.wiring and - # convert these to jax arrays. log_potentials = jax.device_put(bp_state.log_potentials.value) if log_potentials_updates is not None: log_potentials = update_log_potentials( @@ -760,8 +788,8 @@ def update(msgs, _): # update the factor to variable messages delta_msgs = ftov_msgs - msgs msgs = msgs + (1 - damping) * delta_msgs - # Normalize and clip these damped, updated messages before returning - # them. + # Normalize and clip these damped, updated messages before + # returning them. msgs = infer.normalize_and_clip_msgs( msgs, wiring.edges_num_states, @@ -775,6 +803,14 @@ def update(msgs, _): ) def get_bp_state(bp_arrays: BPArrays) -> BPState: + """Reconstruct the BPState from a BPArrays + + Args: + bp_arrays: A BPArrays containing arrays for belief propagation. + + Returns: + The corresponding BPState + """ return BPState( log_potentials=LogPotentials( fg_state=bp_state.fg_state, value=bp_arrays.log_potentials @@ -788,6 +824,14 @@ def get_bp_state(bp_arrays: BPArrays) -> BPState: @jax.jit def get_beliefs(bp_arrays: BPArrays): + """Calculate beliefs from a given BPArrays + + Args: + bp_arrays: A BPArrays containing arrays for belief propagation. + + Returns: + beliefs: An array or a PyTree container containing the beliefs for the variables. + """ evidence = jax.device_put(bp_arrays.evidence) beliefs = bp_state.fg_state.variable_group.unflatten( evidence.at[wiring.var_states_for_edges].add(bp_arrays.ftov_msgs) @@ -799,6 +843,14 @@ def get_beliefs(bp_arrays: BPArrays): @jax.jit def decode_map_states(beliefs: Any): + """Function to decode MAP states given the calculated beliefs. + + Args: + beliefs: An array or a PyTree container containing beliefs for different variables. + + Returns: + An array or a PyTree container containing the MAP states for different variables. + """ map_states = jax.tree_util.tree_map( lambda x: jnp.argmax(x, axis=-1), beliefs, From 802d9b3776e858283e21fdc559bb8dcdab978ca2 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 17:16:55 -0700 Subject: [PATCH 44/56] Separate add factor functions to clarify --- pgmax/fg/graph.py | 193 +++++++++++++++++++++++++++------------------ pgmax/fg/groups.py | 6 +- 2 files changed, 118 insertions(+), 81 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 0b68283c..31bdc7b2 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -8,9 +8,11 @@ from types import MappingProxyType from typing import ( Any, + Callable, Dict, FrozenSet, Hashable, + List, Mapping, Optional, OrderedDict, @@ -94,54 +96,62 @@ def __hash__(self) -> int: def add_factor( self, - *args, - **kwargs, + variable_names: List, + factor_configs: np.ndarray, + log_potentials: Optional[np.ndarray] = None, + name: Optional[str] = None, ) -> None: - """Function to add factor/factor group to this FactorGraph. + """Function to add a single factor to the FactorGraph. Args: - *args: optional sequence of arguments. If specified, and if there is no - "factor_factory" key specified as part of the **kwargs, then these args - are taken to specify the arguments to be used to instantiate an - EnumerationFactor. If there is a "factor_factory" key, then these args - are taken to specify the arguments to be used to construct the class - specified by the "factor_factory" argument. Note that either *args or - **kwargs must be specified. - **kwargs: optional mapping of keyword arguments. If specified, and if there - is no "factor_factory" key specified as part of this mapping, then these - args are taken to specify the arguments to be used to instantiate an - EnumerationFactor (specify a kwarg with the key 'keys' to indicate the - indices of variables ot be indexed to create the EnumerationFactor). - If there is a "factor_factory" key, then these args are taken to specify - the arguments to be used to construct the class specified by the - "factor_factory" argument. - If there is a "name" key, we add the added factor/factor group to the list - of named factors within the factor graph. - Note that either *args or **kwargs must be specified. + variable_names: A list containing the involved variable names. + factor_configs: Array of shape (num_val_configs, num_variables) + An array containing explicit enumeration of all valid configurations + log_potentials: Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). + If specified, it contains the log of the potential value for every possible configuration. + If none, it is assumed the log potential is uniform 0 and such an array is automatically + initialized. + """ + if name in self._named_factor_groups: + raise ValueError( + f"A factor group with the name {name} already exists. Please choose a different name!" + ) + + factor_group = groups.EnumerationFactorGroup( + self._variable_group, + connected_var_keys=[variable_names], + factor_configs=factor_configs, + log_potentials=log_potentials, + ) + self._register_factor_group(factor_group) + + def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: + """Add a factor group to the factor graph + + Args: + factory: Factory function that takes args and kwargs as input and outputs a factor group. + args: Args to be passed to the factory function. + kwargs: kwargs to be passed to the factory function, and an optional "name" argument + for specifying the name of a named factor group. """ name = kwargs.pop("name", None) + factor_group = factory(self._variable_group, *args, **kwargs) + self._register_factor_group(factor_group, name) + + def _register_factor_group( + self, factor_group: groups.FactorGroup, name: Optional[str] = None + ) -> None: if name in self._named_factor_groups: raise ValueError( f"A factor group with the name {name} already exists. Please choose a different name!" ) - factor_factory = kwargs.pop("factor_factory", None) - if factor_factory is not None: - factor_group = factor_factory(self._variable_group, *args, **kwargs) - else: - if len(args) > 0: - new_args = list(args) - new_args[0] = [args[0]] - factor_group = groups.EnumerationFactorGroup( - self._variable_group, *new_args, **kwargs - ) - else: - keys = kwargs.pop("keys") - kwargs["connected_var_keys"] = [keys] - factor_group = groups.EnumerationFactorGroup( - self._variable_group, **kwargs - ) + """Register a factor group to the factor graph, by updating the factor graph state. + Args: + factor_group: The factor group to be registered to the factor graph. + name: Optional name of the factor group. + """ self._factor_group_to_msgs_starts[factor_group] = self._total_factor_num_states self._factor_group_to_potentials_starts[ factor_group @@ -249,27 +259,20 @@ class FactorGraphState: """FactorGraphState. Args: - variable_group: VariableGroup. contains all involved VariableGroups - vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int - representing an index in the evidence array at which the first entry of the evidence - for that particular variable should be placed. - num_var_states: int. represents the sum of all variable states of all variables in the - FactorGraph - total_factor_num_states: - variables_to_factors: - named_factor_groups: Dict[Hashable, groups.FactorGroup]. A dictionary mapping the names of - named factor groups to the corresponding factor groups. - We only support setting messages from factors within explicitly named factor groups - to connected variables. - factor_group_to_potentials_starts: - factor_to_potentials_starts: - factor_group_to_msgs_starts: - factor_to_msgs_starts: - total_factor_num_states: int. Current total number of edge states for the added factors. - factor_group_to_msgs_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its - corresponding starting index in the flat message array. - log_potentials: - wiring: + variable_group: A variable group containing all the variables in the FactorGraph. + vars_to_starts: Maps variables to their starting indices in the flat evidence array. + flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] + contains evidence to the variable. + num_var_states: Total number of variable states. + total_factor_num_states: Size of the flat ftov messages array. + variables_to_factors: Maps sets of involved variables (in the form of frozensets of + variable names) to corresponding factors. + named_factor_groups: Maps the names of named factor groups to the corresponding factor groups. + factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. + factor_to_potentials_starts: Maps factors to their starting indices in the flat log potentials. + factor_to_msgs_starts: Maps factors to their starting indices in the flat ftov messages. + log_potentials: Flat log potentials array. + wiring: Wiring derived from the current set of factors. """ variable_group: groups.VariableGroup @@ -327,6 +330,16 @@ def update_log_potentials( updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState, ) -> jnp.ndarray: + """Function to update log_potentials. + + Args: + log_potentials: A flat jnp array containing log_potentials. + updates: A dictionary containing updates for log_potentials + fg_state: Factor graph state + + Returns: + A flat jnp array containing updated log_potentials. + """ for key in updates: data = updates[key] if key in fg_state.named_factor_groups: @@ -387,6 +400,15 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) def __getitem__(self, key: Any): + """Function to query log potentials for a named factor group or a factor. + + Args: + key: Name of a named factor group, or a frozenset containing the set + of involved variables for the queried factor. + + Returned: + The quried log potentials. + """ if not isinstance(key, Hashable): key = frozenset(key) @@ -412,6 +434,14 @@ def __setitem__( key: Any, data: Union[np.ndarray, jnp.ndarray], ): + """Set the log potentials for a named factor group or a factor. + + Args: + key: Name of a named factor group, or a frozenset containing the set + of involved variables for the queried factor. + data: Array containing the log potentials for the named factor group + or the factor. + """ if not isinstance(key, Hashable): key = frozenset(key) @@ -428,6 +458,16 @@ def __setitem__( def update_ftov_msgs( ftov_msgs: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: + """Function to update ftov_msgs. + + Args: + ftov_msgs: A flat jnp array containing ftov_msgs. + updates: A dictionary containing updates for ftov_msgs + fg_state: Factor graph state + + Returns: + A flat jnp array containing updated ftov_msgs. + """ for keys in updates: data = updates[keys] if ( @@ -483,7 +523,6 @@ class FToVMessages: Args: fg_state: Factor graph state value: Optionally specify initial value for ftov messages - """ fg_state: FactorGraphState @@ -584,6 +623,16 @@ def __setitem__(self, keys, data) -> None: def update_evidence( evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: + """Function to update evidence. + + Args: + evidence: A flat jnp array containing evidence. + updates: A dictionary containing updates for evidence + fg_state: Factor graph state + + Returns: + A flat jnp array containing updated evidence. + """ for key in updates: data = updates[key] if key in fg_state.variable_group.container_keys: @@ -649,29 +698,17 @@ def __getitem__(self, key: Any) -> jnp.ndarray: def __setitem__( self, key: Any, - data: Union[Dict[Hashable, np.ndarray], np.ndarray], + data: np.ndarray, ) -> None: """Function to update the evidence for variables Args: - key: tuple that represents the index into the VariableGroup - (self.fg_state.variable_group) that is created when the FactorGraph is instantiated. Note that - this can be an index referring to an entire VariableGroup (in which case, the evidence - is set for the entire VariableGroup at once), or to an individual Variable within the - VariableGroup. - data: a container for np.ndarrays representing the evidence - Currently supported containers are: - - an np.ndarray: if key indexes an NDVariableArray, then data - can simply be an np.ndarray with num_var_array_dims + 1 dimensions where - num_var_array_dims is the number of dimensions of the NDVariableArray, and the - +1 represents a dimension (that should be the final dimension) for the evidence. - Note that the size of the final dimension should be the same as - variable_group.variable_size. if key indexes a particular variable, then this array - must be of the same size as variable.num_states - - a dictionary: if key indexes a VariableDict, then data - must be a dictionary mapping keys of variable_group to np.ndarrays of evidence values. - Note that each np.ndarray in the dictionary values must have the same size as - variable_group.variable_size. + key: The name of a variable group or a single variable. + If key is the name of a variable group, updates are derived by using the variable group to + flatten the data. + If key is the name of a variable, data should be of an array shape (variable_size,) + If key is None, updates are derived by using self.fg_state.variable_group to flatten the data. + data: Array containing the evidence updates. """ object.__setattr__( self, diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 7ed42bec..c5970008 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -499,7 +499,7 @@ def unflatten( data[key] = flat_data[start : start + self.variable_size] start += self.variable_size else: - data[key] = flat_data[[start]] + data[key] = flat_data[np.array([start])] start += 1 return data @@ -637,7 +637,7 @@ class EnumerationFactorGroup(FactorGroup): initialized. """ - connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] + connected_var_keys: Sequence[List] factor_configs: np.ndarray log_potentials: Optional[np.ndarray] = None @@ -773,7 +773,7 @@ class PairwiseFactorGroup(FactorGroup): VariableGroup) whose keys are present in each sub-list from self.connected_var_keys. """ - connected_var_keys: Sequence[List[Tuple[Hashable, ...]]] + connected_var_keys: Sequence[List] log_potential_matrix: Optional[np.ndarray] = None def _get_variables_to_factors( From 5ab3a3b8b77d5c48bf025199aef26ace20bac2ba Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 18:53:47 -0700 Subject: [PATCH 45/56] Update examples --- examples/heretic_example.py | 4 ++-- examples/ising_model.py | 4 ++-- examples/sanity_check_example.py | 8 ++++---- tests/test_pgmax.py | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/heretic_example.py b/examples/heretic_example.py index 2208cbb1..5e520f35 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -115,8 +115,8 @@ def binary_connected_variables( W_pot = W_orig.swapaxes(0, 1) for k_row in range(3): for k_col in range(3): - fg.add_factor( - factor_factory=groups.PairwiseFactorGroup, + fg.add_factor_group( + factory=groups.PairwiseFactorGroup, connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], ) diff --git a/examples/ising_model.py b/examples/ising_model.py index ee702b9b..23c8b46f 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -36,8 +36,8 @@ connected_var_keys.append([(ii, jj), (kk, jj)]) connected_var_keys.append([(ii, jj), (ii, ll)]) -fg.add_factor( - factor_factory=groups.PairwiseFactorGroup, +fg.add_factor_group( + factory=groups.PairwiseFactorGroup, connected_var_keys=connected_var_keys, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), name="factors", diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index 59c676c8..deb06921 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -313,13 +313,13 @@ def create_valid_suppression_config_arr(suppression_diameter): # ### Add FactorGroups Remaining to FactorGraph # %% -fg.add_factor( - factor_factory=groups.EnumerationFactorGroup, +fg.add_factor_group( + factory=groups.EnumerationFactorGroup, connected_var_keys=vert_suppression_keys, factor_configs=valid_configs_supp, ) -fg.add_factor( - factor_factory=groups.EnumerationFactorGroup, +fg.add_factor_group( + factory=groups.EnumerationFactorGroup, connected_var_keys=horz_suppression_keys, factor_configs=valid_configs_supp, ) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 0d9e0a2f..6f3681d2 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -346,15 +346,15 @@ def create_valid_suppression_config_arr(suppression_diameter): ) # Add the suppression factors to the graph via kwargs - fg.add_factor( - factor_factory=groups.EnumerationFactorGroup, + fg.add_factor_group( + factory=groups.EnumerationFactorGroup, connected_var_keys={ idx: keys for idx, keys in enumerate(vert_suppression_keys) }, factor_configs=valid_configs_supp, ) - fg.add_factor( - factor_factory=groups.EnumerationFactorGroup, + fg.add_factor_group( + factory=groups.EnumerationFactorGroup, connected_var_keys=horz_suppression_keys, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), @@ -406,7 +406,7 @@ def binary_connected_variables( for k_row in range(3): for k_col in range(3): fg.add_factor( - factor_factory=groups.PairwiseFactorGroup, + factory=groups.PairwiseFactorGroup, connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], name=(k_row, k_col), From 030ca0337d4f8f4bbdaee7f40477fe0ab1fb12cc Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 19:00:18 -0700 Subject: [PATCH 46/56] Fix tests --- pgmax/fg/graph.py | 7 +------ tests/fg/test_groups.py | 4 +++- tests/test_pgmax.py | 4 ++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 31bdc7b2..742a0fe5 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -112,18 +112,13 @@ def add_factor( If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized. """ - if name in self._named_factor_groups: - raise ValueError( - f"A factor group with the name {name} already exists. Please choose a different name!" - ) - factor_group = groups.EnumerationFactorGroup( self._variable_group, connected_var_keys=[variable_names], factor_configs=factor_configs, log_potentials=log_potentials, ) - self._register_factor_group(factor_group) + self._register_factor_group(factor_group, name) def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: """Add a factor group to the factor graph diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index fbf594f8..4380bb80 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -89,7 +89,9 @@ def test_variable_dict(): with pytest.raises( ValueError, - match=re.escape("Variable 2 expects a data array of shape (15,). Got (10,)"), + match=re.escape( + "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)" + ), ): variable_dict.flatten({2: np.zeros(10)}) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 6f3681d2..646d204c 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -300,7 +300,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ) else: fg.add_factor( - keys=curr_keys, + variable_names=curr_keys, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float @@ -405,7 +405,7 @@ def binary_connected_variables( W_pot = np.zeros((17, 3, 3, 3), dtype=float) for k_row in range(3): for k_col in range(3): - fg.add_factor( + fg.add_factor_group( factory=groups.PairwiseFactorGroup, connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], From ed952899e2fee555ce82c1b9c762f139c9f23eaa Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 20:31:48 -0700 Subject: [PATCH 47/56] Change key to name --- examples/heretic_example.py | 2 +- examples/ising_model.py | 8 +- examples/sanity_check_example.py | 36 ++-- pgmax/fg/graph.py | 160 +++++++++--------- pgmax/fg/groups.py | 280 ++++++++++++++++--------------- tests/fg/test_graph.py | 8 +- tests/fg/test_groups.py | 26 +-- tests/test_pgmax.py | 46 ++--- 8 files changed, 290 insertions(+), 276 deletions(-) diff --git a/examples/heretic_example.py b/examples/heretic_example.py index 5e520f35..1a888a7a 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -117,7 +117,7 @@ def binary_connected_variables( for k_col in range(3): fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), + connected_var_names=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], ) diff --git a/examples/ising_model.py b/examples/ising_model.py index 23c8b46f..27634072 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -28,17 +28,17 @@ # %% variables = groups.NDVariableArray(variable_size=2, shape=(50, 50)) fg = graph.FactorGraph(variables=variables) -connected_var_keys = [] +connected_var_names = [] for ii in range(50): for jj in range(50): kk = (ii + 1) % 50 ll = (jj + 1) % 50 - connected_var_keys.append([(ii, jj), (kk, jj)]) - connected_var_keys.append([(ii, jj), (ii, ll)]) + connected_var_names.append([(ii, jj), (kk, jj)]) + connected_var_names.append([(ii, jj), (ii, ll)]) fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_keys=connected_var_keys, + connected_var_names=connected_var_names, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), name="factors", ) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index deb06921..cbdcec3d 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -103,14 +103,14 @@ grid_vars_group = groups.NDVariableArray(3, (2, M - 1, N - 1)) # Make a group of additional variables for the edges of the grid -extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] -extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] -additional_keys = tuple(extra_row_keys + extra_col_keys) -additional_keys_group = groups.VariableDict(3, additional_keys) +extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] +extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] +additional_names = tuple(extra_row_names + extra_col_names) +additional_names_group = groups.VariableDict(3, additional_names) # Combine these two VariableGroups into one CompositeVariableGroup composite_grid_group = groups.CompositeVariableGroup( - {"grid_vars": grid_vars_group, "additional_vars": additional_keys_group} + {"grid_vars": grid_vars_group, "additional_vars": additional_names_group} ) @@ -233,14 +233,14 @@ def create_valid_suppression_config_arr(suppression_diameter): for row in range(M - 1): for col in range(N - 1): if row != M - 2 and col != N - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("grid_vars", 0, row, col + 1), ("grid_vars", 1, row + 1, col), ] elif row != M - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("additional_vars", 0, row, col + 1), @@ -248,7 +248,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] elif col != N - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("grid_vars", 0, row, col + 1), @@ -256,7 +256,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] else: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("additional_vars", 0, row, col + 1), @@ -264,25 +264,25 @@ def create_valid_suppression_config_arr(suppression_diameter): ] fg.add_factor( - curr_keys, + curr_names, valid_configs_non_supp, np.zeros(valid_configs_non_supp.shape[0], dtype=float), ) # Create an EnumerationFactorGroup for vertical suppression factors -vert_suppression_keys: List[List[Tuple[Any, ...]]] = [] +vert_suppression_names: List[List[Tuple[Any, ...]]] = [] for col in range(N): for start_row in range(M - SUPPRESSION_DIAMETER): if col != N - 1: - vert_suppression_keys.append( + vert_suppression_names.append( [ ("grid_vars", 0, r, col) for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) else: - vert_suppression_keys.append( + vert_suppression_names.append( [ ("additional_vars", 0, r, col) for r in range(start_row, start_row + SUPPRESSION_DIAMETER) @@ -290,18 +290,18 @@ def create_valid_suppression_config_arr(suppression_diameter): ) -horz_suppression_keys: List[List[Tuple[Any, ...]]] = [] +horz_suppression_names: List[List[Tuple[Any, ...]]] = [] for row in range(M): for start_col in range(N - SUPPRESSION_DIAMETER): if row != M - 1: - horz_suppression_keys.append( + horz_suppression_names.append( [ ("grid_vars", 1, row, c) for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) else: - horz_suppression_keys.append( + horz_suppression_names.append( [ ("additional_vars", 1, row, c) for c in range(start_col, start_col + SUPPRESSION_DIAMETER) @@ -315,12 +315,12 @@ def create_valid_suppression_config_arr(suppression_diameter): # %% fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_keys=vert_suppression_keys, + connected_var_names=vert_suppression_names, factor_configs=valid_configs_supp, ) fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_keys=horz_suppression_keys, + connected_var_names=horz_suppression_names, factor_configs=valid_configs_supp, ) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 742a0fe5..98ea5709 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -114,7 +114,7 @@ def add_factor( """ factor_group = groups.EnumerationFactorGroup( self._variable_group, - connected_var_keys=[variable_names], + connected_var_names=[variable_names], factor_configs=factor_configs, log_potentials=log_potentials, ) @@ -335,27 +335,27 @@ def update_log_potentials( Returns: A flat jnp array containing updated log_potentials. """ - for key in updates: - data = updates[key] - if key in fg_state.named_factor_groups: - factor_group = fg_state.named_factor_groups[key] + for name in updates: + data = updates[name] + if name in fg_state.named_factor_groups: + factor_group = fg_state.named_factor_groups[name] flat_data = factor_group.flatten(data) if flat_data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " - f"for factor group {key}. Got incompatible data shape {data.shape}." + f"for factor group {name}. Got incompatible data shape {data.shape}." ) start = fg_state.factor_group_to_potentials_starts[factor_group] log_potentials = log_potentials.at[start : start + flat_data.shape[0]].set( flat_data ) - elif frozenset(key) in fg_state.variables_to_factors: - factor = fg_state.variables_to_factors[frozenset(key)] + elif frozenset(name) in fg_state.variables_to_factors: + factor = fg_state.variables_to_factors[frozenset(name)] if data.shape != factor.log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor.log_potentials.shape} " - f"for factor {key}. Got {data.shape}." + f"for factor {name}. Got {data.shape}." ) start = fg_state.factor_to_potentials_starts[factor] @@ -363,7 +363,7 @@ def update_log_potentials( start : start + factor.log_potentials.shape[0] ].set(data) else: - raise ValueError(f"Invalid key {key} for log potentials updates.") + raise ValueError(f"Invalid name {name} for log potentials updates.") return log_potentials @@ -394,57 +394,57 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) - def __getitem__(self, key: Any): + def __getitem__(self, name: Any): """Function to query log potentials for a named factor group or a factor. Args: - key: Name of a named factor group, or a frozenset containing the set + name: Name of a named factor group, or a frozenset containing the set of involved variables for the queried factor. Returned: The quried log potentials. """ - if not isinstance(key, Hashable): - key = frozenset(key) + if not isinstance(name, Hashable): + name = frozenset(name) - if key in self.fg_state.named_factor_groups: - factor_group = self.fg_state.named_factor_groups[key] + if name in self.fg_state.named_factor_groups: + factor_group = self.fg_state.named_factor_groups[name] start = self.fg_state.factor_group_to_potentials_starts[factor_group] log_potentials = jax.device_put(self.value)[ start : start + factor_group.factor_group_log_potentials.shape[0] ] - elif frozenset(key) in self.fg_state.variables_to_factors: - factor = self.fg_state.variables_to_factors[frozenset(key)] + elif frozenset(name) in self.fg_state.variables_to_factors: + factor = self.fg_state.variables_to_factors[frozenset(name)] start = self.fg_state.factor_to_potentials_starts[factor] log_potentials = jax.device_put(self.value)[ start : start + factor.log_potentials.shape[0] ] else: - raise ValueError(f"Invalid key {key} for log potentials updates.") + raise ValueError(f"Invalid name {name} for log potentials updates.") return log_potentials def __setitem__( self, - key: Any, + name: Any, data: Union[np.ndarray, jnp.ndarray], ): """Set the log potentials for a named factor group or a factor. Args: - key: Name of a named factor group, or a frozenset containing the set + name: Name of a named factor group, or a frozenset containing the set of involved variables for the queried factor. data: Array containing the log potentials for the named factor group or the factor. """ - if not isinstance(key, Hashable): - key = frozenset(key) + if not isinstance(name, Hashable): + name = frozenset(name) object.__setattr__( self, "value", update_log_potentials( - jax.device_put(self.value), {key: jax.device_put(data)}, self.fg_state + jax.device_put(self.value), {name: jax.device_put(data)}, self.fg_state ), ) @@ -463,32 +463,32 @@ def update_ftov_msgs( Returns: A flat jnp array containing updated ftov_msgs. """ - for keys in updates: - data = updates[keys] + for names in updates: + data = updates[names] if ( - isinstance(keys, tuple) - and len(keys) == 2 - and keys[1] in fg_state.variable_group.keys + isinstance(names, tuple) + and len(names) == 2 + and names[1] in fg_state.variable_group.names ): - factor = fg_state.variables_to_factors[frozenset(keys[0])] - variable = fg_state.variable_group[keys[1]] + factor = fg_state.variables_to_factors[frozenset(names[0])] + variable = fg_state.variable_group[names[1]] start = fg_state.factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) if data.shape != (variable.num_states,): raise ValueError( f"Given message shape {data.shape} does not match expected " - f"shape {(variable.num_states,)} from factor {keys[0]} " - f"to variable {keys[1]}." + f"shape {(variable.num_states,)} from factor {names[0]} " + f"to variable {names[1]}." ) ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set(data) - elif keys in fg_state.variable_group.keys: - variable = fg_state.variable_group[keys] + elif names in fg_state.variable_group.names: + variable = fg_state.variable_group[names] if data.shape != (variable.num_states,): raise ValueError( f"Given belief shape {data.shape} does not match expected " - f"shape {(variable.num_states,)} for variable {keys}." + f"shape {(variable.num_states,)} for variable {names}." ) starts = np.nonzero( @@ -501,10 +501,10 @@ def update_ftov_msgs( ) else: raise ValueError( - "Invalid keys for setting messages. " - "Supported keys include a tuple of length 2 with factor " - "and variable keys for directly setting factor to variable " - "messages, or a valid variable key for spreading expected " + "Invalid names for setting messages. " + "Supported names include a tuple of length 2 with factor " + "and variable names for directly setting factor to variable " + "messages, or a valid variable name for spreading expected " "beliefs at a variable" ) @@ -537,29 +537,29 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) - def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: + def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray: """Function to query messages from a factor to a variable Args: - keys: a tuple of length 2, with keys[0] being the key for - factor, and keys[1] being the key for variable + names: a tuple of length 2, with names[0] being the name for + factor, and names[1] being the name for variable Returns: An array containing the current ftov messages from factor - keys[0] to variable keys[1] + names[0] to variable names[1] """ if not ( - isinstance(keys, tuple) - and len(keys) == 2 - and keys[1] in self.fg_state.variable_group.keys + isinstance(names, tuple) + and len(names) == 2 + and names[1] in self.fg_state.variable_group.names ): raise ValueError( - f"Invalid keys {keys}. Please specify a tuple of factor, variable " - "keys to get the messages from a named factor to a variable" + f"Invalid names {names}. Please specify a tuple of factor, variable " + "names to get the messages from a named factor to a variable" ) - factor = self.fg_state.variables_to_factors[frozenset(keys[0])] - variable = self.fg_state.variable_group[keys[1]] + factor = self.fg_state.variables_to_factors[frozenset(names[0])] + variable = self.fg_state.variable_group[names[1]] start = self.fg_state.factor_to_msgs_starts[factor] + np.sum( factor.edges_num_states[: factor.variables.index(variable)] ) @@ -569,47 +569,47 @@ def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: @typing.overload def __setitem__( self, - keys: Tuple[Any, Any], + names: Tuple[Any, Any], data: Union[np.ndarray, jnp.ndarray], ) -> None: """Setting messages from a factor to a variable Args: - keys: A tuple of length 2 - keys[0] is the key of the factor - keys[1] is the key of the variable - data: An array containing messages from factor keys[0] - to variable keys[1] + names: A tuple of length 2 + names[0] is the name of the factor + names[1] is the name of the variable + data: An array containing messages from factor names[0] + to variable names[1] """ @typing.overload def __setitem__( self, - keys: Any, + names: Any, data: Union[np.ndarray, jnp.ndarray], ) -> None: """Spreading beliefs at a variable to all connected factors Args: - keys: The key of the variable + names: The name of the variable data: An array containing the beliefs to be spread uniformly across all factor to variable messages involving this variable. """ - def __setitem__(self, keys, data) -> None: + def __setitem__(self, names, data) -> None: if ( - isinstance(keys, tuple) - and len(keys) == 2 - and keys[1] in self.fg_state.variable_group.keys + isinstance(names, tuple) + and len(names) == 2 + and names[1] in self.fg_state.variable_group.names ): - keys = (frozenset(keys[0]), keys[1]) + names = (frozenset(names[0]), names[1]) object.__setattr__( self, "value", update_ftov_msgs( - jax.device_put(self.value), {keys: jax.device_put(data)}, self.fg_state + jax.device_put(self.value), {names: jax.device_put(data)}, self.fg_state ), ) @@ -628,16 +628,16 @@ def update_evidence( Returns: A flat jnp array containing updated evidence. """ - for key in updates: - data = updates[key] - if key in fg_state.variable_group.container_keys: - if key is None: + for name in updates: + data = updates[name] + if name in fg_state.variable_group.container_names: + if name is None: variable_group = fg_state.variable_group else: assert isinstance( fg_state.variable_group, groups.CompositeVariableGroup ) - variable_group = fg_state.variable_group.variable_group_container[key] + variable_group = fg_state.variable_group.variable_group_container[name] start_index = fg_state.vars_to_starts[variable_group.variables[0]] flat_data = variable_group.flatten(data) @@ -645,7 +645,7 @@ def update_evidence( flat_data ) else: - var = fg_state.variable_group[key] + var = fg_state.variable_group[name] start_index = fg_state.vars_to_starts[var] evidence = evidence.at[start_index : start_index + var.num_states].set(data) @@ -676,40 +676,40 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) - def __getitem__(self, key: Any) -> jnp.ndarray: + def __getitem__(self, name: Any) -> jnp.ndarray: """Function to query evidence for a variable Args: - key: key for the variable + name: name for the variable Returns: evidence for the queried variable """ - variable = self.fg_state.variable_group[key] + variable = self.fg_state.variable_group[name] start = self.fg_state.vars_to_starts[variable] evidence = jax.device_put(self.value)[start : start + variable.num_states] return evidence def __setitem__( self, - key: Any, + name: Any, data: np.ndarray, ) -> None: """Function to update the evidence for variables Args: - key: The name of a variable group or a single variable. - If key is the name of a variable group, updates are derived by using the variable group to + name: The name of a variable group or a single variable. + If name is the name of a variable group, updates are derived by using the variable group to flatten the data. - If key is the name of a variable, data should be of an array shape (variable_size,) - If key is None, updates are derived by using self.fg_state.variable_group to flatten the data. + If name is the name of a variable, data should be of an array shape (variable_size,) + If name is None, updates are derived by using self.fg_state.variable_group to flatten the data. data: Array containing the evidence updates. """ object.__setattr__( self, "value", update_evidence( - jax.device_put(self.value), {key: jax.device_put(data)}, self.fg_state + jax.device_put(self.value), {name: jax.device_put(data)}, self.fg_state ), ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index c5970008..4a8df79c 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -31,79 +31,80 @@ class VariableGroup: """Class to represent a group of variables. All variables in the group are assumed to have the same size. Additionally, the - variables are indexed by a "key", and can be retrieved by direct indexing (even indexing - a sequence of keys) of the VariableGroup. + variables are indexed by a variable name, and can be retrieved by direct indexing (even indexing + a sequence of variable names) of the VariableGroup. Attributes: - _keys_to_vars: A private, immutable mapping from keys to variables + _names_to_variables: A private, immutable mapping from variable names to variables """ - _keys_to_vars: Mapping[Hashable, nodes.Variable] = field(init=False) + _names_to_variables: Mapping[Hashable, nodes.Variable] = field(init=False) def __post_init__(self) -> None: - """Initialize a private, immutable mapping from keys to variables.""" + """Initialize a private, immutable mapping from variable names to variables.""" object.__setattr__( - self, "_keys_to_vars", MappingProxyType(self._get_keys_to_vars()) + self, + "_names_to_variables", + MappingProxyType(self._get_names_to_variables()), ) @typing.overload - def __getitem__(self, key: Hashable) -> nodes.Variable: + def __getitem__(self, name: Hashable) -> nodes.Variable: """This function is a typing overload and is overwritten by the implemented __getitem__""" @typing.overload - def __getitem__(self, key: List) -> List[nodes.Variable]: + def __getitem__(self, name: List) -> List[nodes.Variable]: """This function is a typing overload and is overwritten by the implemented __getitem__""" - def __getitem__(self, key): - """Given a key, retrieve the associated Variable. + def __getitem__(self, name): + """Given a name, retrieve the associated Variable. Args: - key: a single key corresponding to a single variable, or a list of such keys + name: a single name corresponding to a single variable, or a list of such names Returns: - a single variable if the "key" argument is a single key. Otherwise, returns a list of - variables corresponding to each key in the "key" argument. + A single variable if the name is not a list. A list of variables if name is a list """ - if isinstance(key, List): - keys_list = key + if isinstance(name, List): + names_list = name else: - keys_list = [key] + names_list = [name] vars_list = [] - for curr_key in keys_list: - var = self._keys_to_vars.get(curr_key) + for curr_name in names_list: + var = self._names_to_variables.get(curr_name) if var is None: raise ValueError( - f"The key {curr_key} is not present in the VariableGroup {type(self)}; please ensure " + f"The name {curr_name} is not present in the VariableGroup {type(self)}; please ensure " "it's been added to the VariableGroup before trying to query it." ) vars_list.append(var) - if isinstance(key, List): + if isinstance(name, List): return vars_list else: return vars_list[0] - def _get_keys_to_vars(self) -> OrderedDict[Any, nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. + def _get_names_to_variables(self) -> OrderedDict[Any, nodes.Variable]: + """Function that generates a dictionary mapping names to variables. Returns: - a dictionary mapping all possible keys to different variables. + a dictionary mapping all possible names to different variables. """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) @cached_property - def keys(self) -> Tuple[Any, ...]: - """Function to return a tuple of all keys in the group. + def names(self) -> Tuple[Any, ...]: + """Function to return a tuple of all names in the group. Returns: - tuple of all keys that are part of this VariableGroup + tuple of all names that are part of this VariableGroup """ - return tuple(self._keys_to_vars.keys()) + return tuple(self._names_to_variables.keys()) @cached_property def variables(self) -> Tuple[nodes.Variable, ...]: @@ -112,10 +113,10 @@ def variables(self) -> Tuple[nodes.Variable, ...]: Returns: tuple of all variable that are part of this VariableGroup """ - return tuple(self._keys_to_vars.values()) + return tuple(self._names_to_variables.values()) @cached_property - def container_keys(self) -> Tuple: + def container_names(self) -> Tuple: """Placeholder function. Returns a tuple containing None for all variable groups other than a composite variable group """ @@ -154,17 +155,17 @@ class CompositeVariableGroup(VariableGroup): This class enables users to wrap various different VariableGroups and then index them in a straightforward manner. To index into a CompositeVariableGroup, simply - provide the "key" of the VariableGroup within this CompositeVariableGroup followed - by the key to be indexed within the VariableGroup. + provide the name of the VariableGroup within this CompositeVariableGroup followed + by the name to be indexed within the VariableGroup. Args: variable_group_container: A container containing multiple variable groups. Supported containers include mapping and sequence. - For a mapping, the keys of the mapping are used to index the variable groups. + For a mapping, the names of the mapping are used to index the variable groups. For a sequence, the indices of the sequence are used to index the variable groups. Attributes: - _keys_to_vars: A private, immutable mapping from keys to variables + _names_to_variables: A private, immutable mapping from names to variables """ variable_group_container: Union[ @@ -173,70 +174,79 @@ class CompositeVariableGroup(VariableGroup): def __post_init__(self): object.__setattr__( - self, "_keys_to_vars", MappingProxyType(self._get_keys_to_vars()) + self, + "_names_to_variables", + MappingProxyType(self._get_names_to_variables()), ) @typing.overload - def __getitem__(self, key: Hashable) -> nodes.Variable: + def __getitem__(self, name: Hashable) -> nodes.Variable: """This function is a typing overload and is overwritten by the implemented __getitem__""" @typing.overload - def __getitem__(self, key: List) -> List[nodes.Variable]: + def __getitem__(self, name: List) -> List[nodes.Variable]: """This function is a typing overload and is overwritten by the implemented __getitem__""" - def __getitem__(self, key): - """Given a key, retrieve the associated Variable from the associated VariableGroup. + def __getitem__(self, name): + """Given a name, retrieve the associated Variable from the associated VariableGroup. Args: - key: a single key corresponding to a single Variable within a VariableGroup, or a list - of such keys + name: a single name corresponding to a single Variable within a VariableGroup, or a list + of such names Returns: - a single variable if the "key" argument is a single key. Otherwise, returns a list of - variables corresponding to each key in the "key" argument. + A single variable if the name is not a list. A list of variables if name is a list """ - if isinstance(key, List): - keys_list = key + if isinstance(name, List): + names_list = name else: - keys_list = [key] + names_list = [name] vars_list = [] - for curr_key in keys_list: - if len(curr_key) < 2: + for curr_name in names_list: + if len(curr_name) < 2: raise ValueError( - "The key needs to have at least 2 elements to index from a composite variable group." + "The name needs to have at least 2 elements to index from a composite variable group." ) - variable_group = self.variable_group_container[curr_key[0]] - if len(curr_key) == 2: - vars_list.append(variable_group[curr_key[1]]) + variable_group = self.variable_group_container[curr_name[0]] + if len(curr_name) == 2: + vars_list.append(variable_group[curr_name[1]]) else: - vars_list.append(variable_group[curr_key[1:]]) + vars_list.append(variable_group[curr_name[1:]]) - if isinstance(key, List): + if isinstance(name, List): return vars_list else: return vars_list[0] - def _get_keys_to_vars(self) -> OrderedDict[Hashable, nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. + def _get_names_to_variables(self) -> OrderedDict[Hashable, nodes.Variable]: + """Function that generates a dictionary mapping names to variables. Returns: - a dictionary mapping all possible keys to different variables. + a dictionary mapping all possible names to different variables. """ - keys_to_vars: OrderedDict[Hashable, nodes.Variable] = collections.OrderedDict() - for container_key in self.container_keys: - for variable_group_key in self.variable_group_container[container_key].keys: - if isinstance(variable_group_key, tuple): - keys_to_vars[ - (container_key,) + variable_group_key - ] = self.variable_group_container[container_key][variable_group_key] + names_to_variables: OrderedDict[ + Hashable, nodes.Variable + ] = collections.OrderedDict() + for container_name in self.container_names: + for variable_group_name in self.variable_group_container[ + container_name + ].names: + if isinstance(variable_group_name, tuple): + names_to_variables[ + (container_name,) + variable_group_name + ] = self.variable_group_container[container_name][ + variable_group_name + ] else: - keys_to_vars[ - (container_key, variable_group_key) - ] = self.variable_group_container[container_key][variable_group_key] + names_to_variables[ + (container_name, variable_group_name) + ] = self.variable_group_container[container_name][ + variable_group_name + ] - return keys_to_vars + return names_to_variables def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -251,8 +261,8 @@ def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: """ flat_data = jnp.concatenate( [ - self.variable_group_container[key].flatten(data[key]) - for key in self.container_keys + self.variable_group_container[name].flatten(data[name]) + for name in self.container_names ] ) return flat_data @@ -275,8 +285,8 @@ def unflatten( num_variables = 0 num_variable_states = 0 - for key in self.container_keys: - variable_group = self.variable_group_container[key] + for name in self.container_names: + variable_group = self.variable_group_container[name] num_variables += len(variable_group.variables) num_variable_states += ( len(variable_group.variables) * variable_group.variables[0].num_states @@ -295,8 +305,8 @@ def unflatten( data: List[np.ndarray] = [] start = 0 - for key in self.container_keys: - variable_group = self.variable_group_container[key] + for name in self.container_names: + variable_group = self.variable_group_container[name] length = len(variable_group.variables) if use_num_states: length *= variable_group.variables[0].num_states @@ -304,25 +314,27 @@ def unflatten( data.append(variable_group.unflatten(flat_data[start : start + length])) start += length if isinstance(self.variable_group_container, Mapping): - return dict([(key, data[kk]) for kk, key in enumerate(self.container_keys)]) + return dict( + [(name, data[kk]) for kk, name in enumerate(self.container_names)] + ) else: return data @cached_property - def container_keys(self) -> Tuple: - """Function to get keys referring to the variable groups within this + def container_names(self) -> Tuple: + """Function to get names referring to the variable groups within this CompositeVariableGroup. Returns: - a tuple of the keys referring to the variable groups within this + a tuple of the names referring to the variable groups within this CompositeVariableGroup. """ if isinstance(self.variable_group_container, Mapping): - container_keys = tuple(self.variable_group_container.keys()) + container_names = tuple(self.variable_group_container.keys()) else: - container_keys = tuple(range(len(self.variable_group_container))) + container_names = tuple(range(len(self.variable_group_container))) - return container_keys + return container_names @dataclass(frozen=True, eq=False) @@ -338,24 +350,24 @@ class NDVariableArray(VariableGroup): variable_size: int shape: Tuple[int, ...] - def _get_keys_to_vars( + def _get_names_to_variables( self, ) -> OrderedDict[Union[int, Tuple[int, ...]], nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. + """Function that generates a dictionary mapping names to variables. Returns: - a dictionary mapping all possible keys to different variables. + a dictionary mapping all possible names to different variables. """ - keys_to_vars: OrderedDict[ + names_to_variables: OrderedDict[ Union[int, Tuple[int, ...]], nodes.Variable ] = collections.OrderedDict() - for key in itertools.product(*[list(range(k)) for k in self.shape]): - if len(key) == 1: - keys_to_vars[key[0]] = nodes.Variable(self.variable_size) + for name in itertools.product(*[list(range(k)) for k in self.shape]): + if len(name) == 1: + names_to_variables[name[0]] = nodes.Variable(self.variable_size) else: - keys_to_vars[key] = nodes.Variable(self.variable_size) + names_to_variables[name] = nodes.Variable(self.variable_size) - return keys_to_vars + return names_to_variables def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -420,19 +432,19 @@ class VariableDict(VariableGroup): variable_size: int variable_names: Tuple[Any, ...] - def _get_keys_to_vars(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: - """Function that generates a dictionary mapping keys to variables. + def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: + """Function that generates a dictionary mapping names to variables. Returns: - a dictionary mapping all possible keys to different variables. + a dictionary mapping all possible names to different variables. """ - keys_to_vars: OrderedDict[ + names_to_variables: OrderedDict[ Tuple[Any, ...], nodes.Variable ] = collections.OrderedDict() - for key in self.variable_names: - keys_to_vars[key] = nodes.Variable(self.variable_size) + for name in self.variable_names: + names_to_variables[name] = nodes.Variable(self.variable_size) - return keys_to_vars + return names_to_variables def flatten( self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] @@ -440,24 +452,26 @@ def flatten( """Function that turns meaningful structured data into a flat data array for internal use. Args: - data: Meaningful structured data. Should be a mapping with keys from self.variable_names. + data: Meaningful structured data. Should be a mapping with names from self.variable_names. Each value should be an array of shape (1,) (for e.g. MAP decodings) or (self.variable_size,) (for e.g. evidence, beliefs). Returns: A flat jnp.array for internal use """ - for key in data: - if key not in self._keys_to_vars: - raise ValueError(f"data is referring to a non-existent variable {key}.") + for name in data: + if name not in self._names_to_variables: + raise ValueError( + f"data is referring to a non-existent variable {name}." + ) - if data[key].shape != (self.variable_size,) and data[key].shape != (1,): + if data[name].shape != (self.variable_size,) and data[name].shape != (1,): raise ValueError( - f"Variable {key} expects a data array of shape " - f"{(self.variable_size,)} or (1,). Got {data[key].shape}." + f"Variable {name} expects a data array of shape " + f"{(self.variable_size,)} or (1,). Got {data[name].shape}." ) - flat_data = jnp.concatenate([data[key].flatten() for key in self.keys]) + flat_data = jnp.concatenate([data[name].flatten() for name in self.names]) return flat_data def unflatten( @@ -469,7 +483,7 @@ def unflatten( flat_data: Internal flat data array. Returns: - Meaningful structured data. Should be a mapping with keys from self.variable_names. + Meaningful structured data. Should be a mapping with names from self.variable_names. Each value should be an array of shape (1,) (for e.g. MAP decodings) or (self.variable_size,) (for e.g. evidence, beliefs). @@ -494,12 +508,12 @@ def unflatten( start = 0 data = {} - for key in self.variable_names: + for name in self.variable_names: if use_num_states: - data[key] = flat_data[start : start + self.variable_size] + data[name] = flat_data[start : start + self.variable_size] start += self.variable_size else: - data[key] = flat_data[np.array([start])] + data[name] = flat_data[np.array([start])] start += 1 return data @@ -518,7 +532,7 @@ class FactorGroup: _variables_to_factors: maps set of involved variables to the corresponding factors Raises: - ValueError: if connected_var_keys is an empty list + ValueError: if connected_var_names is an empty list """ variable_group: Union[CompositeVariableGroup, VariableGroup] @@ -568,10 +582,10 @@ def factor_group_log_potentials(self) -> np.ndarray: def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: - """Function that generates a dictionary mapping keys to factors. + """Function that generates a dictionary mapping names to factors. Returns: - a dictionary mapping all possible keys to different factors. + a dictionary mapping all possible names to different factors. """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" @@ -626,8 +640,8 @@ class EnumerationFactorGroup(FactorGroup): uniform 0 unless the inheriting class includes a log_potentials argument. Args: - connected_var_keys: A list of list of tuples, where each innermost tuple contains a - key into variable_group. Each list within the outer list is taken to contain the keys of variables + connected_var_names: A list of list of tuples, where each innermost tuple contains a + name into variable_group. Each list within the outer list is taken to contain the names of variables neighboring a particular factor to be added. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations @@ -637,7 +651,7 @@ class EnumerationFactorGroup(FactorGroup): initialized. """ - connected_var_keys: Sequence[List] + connected_var_names: Sequence[List] factor_configs: np.ndarray log_potentials: Optional[np.ndarray] = None @@ -649,7 +663,7 @@ def _get_variables_to_factors( Returns: a dictionary mapping all possible set of involved variables to different factors. """ - num_factors = len(self.connected_var_keys) + num_factors = len(self.connected_var_names) num_val_configs = self.factor_configs.shape[0] if self.log_potentials is None: log_potentials = np.zeros((num_factors, num_val_configs), dtype=float) @@ -672,14 +686,14 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.connected_var_keys[ii]), + frozenset(self.connected_var_names[ii]), nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), + tuple(self.variable_group[self.connected_var_names[ii]]), self.factor_configs, log_potentials[ii], ), ) - for ii in range(len(self.connected_var_keys)) + for ii in range(len(self.connected_var_names)) ] ) return variables_to_factors @@ -765,15 +779,15 @@ class PairwiseFactorGroup(FactorGroup): one CompositeVariableGroup. Args: - connected_var_keys: A list of list of tuples, where each innermost tuple contains a - key into variable_group. Each list within the outer list is taken to contain the keys of variables + connected_var_names: A list of list of tuples, where each innermost tuple contains a + name into variable_group. Each list within the outer list is taken to contain the names of variables neighboring a particular factor to be added. log_potential_matrix: array of shape (var1.variable_size, var2.variable_size), where var1 and var2 are the 2 VariableGroups (that may refer to the same - VariableGroup) whose keys are present in each sub-list from self.connected_var_keys. + VariableGroup) whose names are present in each sub-list from self.connected_var_names. """ - connected_var_keys: Sequence[List] + connected_var_names: Sequence[List] log_potential_matrix: Optional[np.ndarray] = None def _get_variables_to_factors( @@ -785,15 +799,15 @@ def _get_variables_to_factors( a dictionary mapping all possible set of involved variables to different factors. Raises: - ValueError: if every sub-list within self.connected_var_keys has len != 2, or if the shape of the + ValueError: if every sub-list within self.connected_var_names has len != 2, or if the shape of the log_potential_matrix is not the same as the variable sizes for each variable referenced in - each sub-list of self.connected_var_keys + each sub-list of self.connected_var_names """ if self.log_potential_matrix is None: log_potential_matrix = np.zeros( ( - self.variable_group[self.connected_var_keys[0][0]].num_states, - self.variable_group[self.connected_var_keys[0][1]].num_states, + self.variable_group[self.connected_var_names[0][0]].num_states, + self.variable_group[self.connected_var_names[0][1]].num_states, ) ) else: @@ -807,14 +821,14 @@ def _get_variables_to_factors( ) if log_potential_matrix.ndim == 3 and log_potential_matrix.shape[0] != len( - self.connected_var_keys + self.connected_var_names ): raise ValueError( - f"Expected log_potential_matrix for {len(self.connected_var_keys)} factors. " + f"Expected log_potential_matrix for {len(self.connected_var_names)} factors. " f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) - for fac_list in self.connected_var_keys: + for fac_list in self.connected_var_names: if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" @@ -846,21 +860,21 @@ def _get_variables_to_factors( object.__setattr__(self, "log_potential_matrix", log_potential_matrix) log_potential_matrix = np.broadcast_to( log_potential_matrix, - (len(self.connected_var_keys),) + log_potential_matrix.shape[-2:], + (len(self.connected_var_names),) + log_potential_matrix.shape[-2:], ) variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.connected_var_keys[ii]), + frozenset(self.connected_var_names[ii]), nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_keys[ii]]), + tuple(self.variable_group[self.connected_var_names[ii]]), factor_configs, log_potential_matrix[ ii, factor_configs[:, 0], factor_configs[:, 1] ], ), ) - for ii in range(len(self.connected_var_keys)) + for ii in range(len(self.connected_var_names)) ] ) return variables_to_factors diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index b6ab4c1c..a3c9794d 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -65,7 +65,7 @@ def test_log_potentials(): with pytest.raises( ValueError, - match=re.escape(f"Invalid key {frozenset([1])} for log potentials updates."), + match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."), ): fg.bp_state.log_potentials[frozenset([1])] = np.zeros(10) @@ -79,7 +79,7 @@ def test_log_potentials(): assert jnp.all(log_potentials[[0]] == jnp.zeros(10)) with pytest.raises( ValueError, - match=re.escape(f"Invalid key {frozenset([1])} for log potentials updates."), + match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."), ): fg.bp_state.log_potentials[[1]] @@ -106,7 +106,7 @@ def test_ftov_msgs(): with pytest.raises( ValueError, - match=re.escape("Invalid keys for setting messages"), + match=re.escape("Invalid names for setting messages"), ): fg.bp_state.ftov_msgs[1] = np.ones(10) @@ -116,7 +116,7 @@ def test_ftov_msgs(): graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) - with pytest.raises(ValueError, match=re.escape("Invalid keys (10,)")): + with pytest.raises(ValueError, match=re.escape("Invalid names (10,)")): ftov_msgs[(10,)] diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 4380bb80..f7c3e2c9 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -17,7 +17,7 @@ def test_composite_variable_group(): composite_variable_dict = groups.CompositeVariableGroup( {(0, 1): variable_dict1, (2, 3): variable_dict2} ) - with pytest.raises(ValueError, match="The key needs to have at least 2 elements"): + with pytest.raises(ValueError, match="The name needs to have at least 2 elements"): composite_variable_sequence[(0,)] assert composite_variable_sequence[0, 1] == variable_dict1[1] @@ -32,12 +32,12 @@ def test_composite_variable_group(): ] assert jnp.all( composite_variable_sequence.flatten( - [{key: np.zeros(15) for key in range(3)} for _ in range(2)] + [{name: np.zeros(15) for name in range(3)} for _ in range(2)] ) == composite_variable_dict.flatten( { - (0, 1): {key: np.zeros(15) for key in range(3)}, - (2, 3): {key: np.zeros(15) for key in range(3)}, + (0, 1): {name: np.zeros(15) for name in range(3)}, + (2, 3): {name: np.zeros(15) for name in range(3)}, } ) ) @@ -47,7 +47,7 @@ def test_composite_variable_group(): jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)), - [{key: jnp.zeros(15) for key in range(3)} for _ in range(2)], + [{name: jnp.zeros(15) for name in range(3)} for _ in range(2)], ) ) ) @@ -59,8 +59,8 @@ def test_composite_variable_group(): lambda x, y: jnp.all(x == y), composite_variable_dict.unflatten(jnp.zeros(3 * 2)), { - (0, 1): {key: np.zeros(1) for key in range(3)}, - (2, 3): {key: np.zeros(1) for key in range(3)}, + (0, 1): {name: np.zeros(1) for name in range(3)}, + (2, 3): {name: np.zeros(1) for name in range(3)}, }, ) ) @@ -106,7 +106,7 @@ def test_variable_dict(): jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), variable_dict.unflatten(jnp.zeros(3)), - {key: np.zeros(1) for key in range(3)}, + {name: np.zeros(1) for name in range(3)}, ) ) ) @@ -158,24 +158,24 @@ def test_enumeration_factor_group(): ): enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, - connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + connected_var_names=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], factor_configs=np.zeros((1, 3), dtype=int), log_potentials=np.zeros((3, 2)), ) enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, - connected_var_keys=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + connected_var_names=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], factor_configs=np.zeros((1, 3), dtype=int), ) - key = [(0, 0), (1, 1)] + name = [(0, 0), (1, 1)] with pytest.raises( ValueError, match=re.escape( - f"The queried factor {frozenset(key)} is not present in the factor group." + f"The queried factor {frozenset(name)} is not present in the factor group." ), ): - enumeration_factor_group[key] + enumeration_factor_group[name] assert ( enumeration_factor_group[[(0, 1), (1, 0), (1, 1)]] diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 646d204c..5df6bdfa 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -213,14 +213,14 @@ def create_valid_suppression_config_arr(suppression_diameter): grid_vars_group = groups.NDVariableArray(3, (2, M - 1, N - 1)) # Make a group of additional variables for the edges of the grid - extra_row_keys: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] - extra_col_keys: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] - additional_keys = tuple(extra_row_keys + extra_col_keys) - additional_keys_group = groups.VariableDict(3, additional_keys) + extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] + extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] + additional_names = tuple(extra_row_names + extra_col_names) + additional_names_group = groups.VariableDict(3, additional_names) # Combine these two VariableGroups into one CompositeVariableGroup composite_grid_group = groups.CompositeVariableGroup( - {"grid_vars": grid_vars_group, "additional_vars": additional_keys_group} + {"grid_vars": grid_vars_group, "additional_vars": additional_names_group} ) gt_has_cuts = gt_has_cuts.astype(np.int32) @@ -262,14 +262,14 @@ def create_valid_suppression_config_arr(suppression_diameter): for row in range(M - 1): for col in range(N - 1): if row != M - 2 and col != N - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("grid_vars", 0, row, col + 1), ("grid_vars", 1, row + 1, col), ] elif row != M - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("additional_vars", 0, row, col + 1), @@ -277,7 +277,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] elif col != N - 2: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("grid_vars", 0, row, col + 1), @@ -285,7 +285,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] else: - curr_keys = [ + curr_names = [ ("grid_vars", 0, row, col), ("grid_vars", 1, row, col), ("additional_vars", 0, row, col + 1), @@ -293,14 +293,14 @@ def create_valid_suppression_config_arr(suppression_diameter): ] if row % 2 == 0: fg.add_factor( - curr_keys, + curr_names, valid_configs_non_supp, np.zeros(valid_configs_non_supp.shape[0], dtype=float), name=(row, col), ) else: fg.add_factor( - variable_names=curr_keys, + variable_names=curr_names, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float @@ -309,36 +309,36 @@ def create_valid_suppression_config_arr(suppression_diameter): ) # Create an EnumerationFactorGroup for vertical suppression factors - vert_suppression_keys: List[List[Tuple[Any, ...]]] = [] + vert_suppression_names: List[List[Tuple[Any, ...]]] = [] for col in range(N): for start_row in range(M - SUPPRESSION_DIAMETER): if col != N - 1: - vert_suppression_keys.append( + vert_suppression_names.append( [ ("grid_vars", 0, r, col) for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) else: - vert_suppression_keys.append( + vert_suppression_names.append( [ ("additional_vars", 0, r, col) for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) - horz_suppression_keys: List[List[Tuple[Any, ...]]] = [] + horz_suppression_names: List[List[Tuple[Any, ...]]] = [] for row in range(M): for start_col in range(N - SUPPRESSION_DIAMETER): if row != M - 1: - horz_suppression_keys.append( + horz_suppression_names.append( [ ("grid_vars", 1, row, c) for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) else: - horz_suppression_keys.append( + horz_suppression_names.append( [ ("additional_vars", 1, row, c) for c in range(start_col, start_col + SUPPRESSION_DIAMETER) @@ -348,14 +348,14 @@ def create_valid_suppression_config_arr(suppression_diameter): # Add the suppression factors to the graph via kwargs fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_keys={ - idx: keys for idx, keys in enumerate(vert_suppression_keys) + connected_var_names={ + idx: names for idx, names in enumerate(vert_suppression_names) }, factor_configs=valid_configs_supp, ) fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_keys=horz_suppression_keys, + connected_var_names=horz_suppression_names, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) @@ -370,8 +370,8 @@ def create_valid_suppression_config_arr(suppression_diameter): # Test that the output messages are close to the true messages assert jnp.allclose(bp_arrays.ftov_msgs, true_final_msgs_output, atol=1e-06) decoded_map_states = graph.decode_map_states(get_beliefs(bp_arrays)) - for key in true_map_state_output: - assert true_map_state_output[key] == decoded_map_states[key[0]][key[1:]] + for name in true_map_state_output: + assert true_map_state_output[name] == decoded_map_states[name[0]][name[1:]] def test_e2e_heretic(): @@ -407,7 +407,7 @@ def binary_connected_variables( for k_col in range(3): fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), + connected_var_names=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], name=(k_row, k_col), ) From bd8899c7579b3bae6e9061733d539781adf07b97 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 20:38:13 -0700 Subject: [PATCH 48/56] More renaming --- examples/heretic_example.py | 2 +- examples/ising_model.py | 8 ++--- examples/sanity_check_example.py | 4 +-- pgmax/fg/graph.py | 10 +++--- pgmax/fg/groups.py | 56 ++++++++++++++++---------------- pgmax/fg/nodes.py | 2 +- tests/fg/test_groups.py | 7 ++-- tests/test_pgmax.py | 8 +++-- 8 files changed, 51 insertions(+), 46 deletions(-) diff --git a/examples/heretic_example.py b/examples/heretic_example.py index 1a888a7a..87a7272b 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -117,7 +117,7 @@ def binary_connected_variables( for k_col in range(3): fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_names=binary_connected_variables(28, 28, k_row, k_col), + connected_variable_names=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], ) diff --git a/examples/ising_model.py b/examples/ising_model.py index 27634072..6f295ee3 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -28,17 +28,17 @@ # %% variables = groups.NDVariableArray(variable_size=2, shape=(50, 50)) fg = graph.FactorGraph(variables=variables) -connected_var_names = [] +connected_variable_names = [] for ii in range(50): for jj in range(50): kk = (ii + 1) % 50 ll = (jj + 1) % 50 - connected_var_names.append([(ii, jj), (kk, jj)]) - connected_var_names.append([(ii, jj), (ii, ll)]) + connected_variable_names.append([(ii, jj), (kk, jj)]) + connected_variable_names.append([(ii, jj), (ii, ll)]) fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_names=connected_var_names, + connected_variable_names=connected_variable_names, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), name="factors", ) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index cbdcec3d..fa300df5 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -315,12 +315,12 @@ def create_valid_suppression_config_arr(suppression_diameter): # %% fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_names=vert_suppression_names, + connected_variable_names=vert_suppression_names, factor_configs=valid_configs_supp, ) fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_names=horz_suppression_names, + connected_variable_names=horz_suppression_names, factor_configs=valid_configs_supp, ) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 98ea5709..692464d4 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -104,7 +104,7 @@ def add_factor( """Function to add a single factor to the FactorGraph. Args: - variable_names: A list containing the involved variable names. + variable_names: A list containing the connected variable names. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations log_potentials: Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). @@ -114,7 +114,7 @@ def add_factor( """ factor_group = groups.EnumerationFactorGroup( self._variable_group, - connected_var_names=[variable_names], + connected_variable_names=[variable_names], factor_configs=factor_configs, log_potentials=log_potentials, ) @@ -260,7 +260,7 @@ class FactorGraphState: contains evidence to the variable. num_var_states: Total number of variable states. total_factor_num_states: Size of the flat ftov messages array. - variables_to_factors: Maps sets of involved variables (in the form of frozensets of + variables_to_factors: Maps sets of connected variables (in the form of frozensets of variable names) to corresponding factors. named_factor_groups: Maps the names of named factor groups to the corresponding factor groups. factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. @@ -399,7 +399,7 @@ def __getitem__(self, name: Any): Args: name: Name of a named factor group, or a frozenset containing the set - of involved variables for the queried factor. + of connected variables for the queried factor. Returned: The quried log potentials. @@ -433,7 +433,7 @@ def __setitem__( Args: name: Name of a named factor group, or a frozenset containing the set - of involved variables for the queried factor. + of connected variables for the queried factor. data: Array containing the log potentials for the named factor group or the factor. """ diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 4a8df79c..509fc23c 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -529,10 +529,10 @@ class FactorGroup: all the variables that are connected to this FactorGroup Attributes: - _variables_to_factors: maps set of involved variables to the corresponding factors + _variables_to_factors: maps set of connected variables to the corresponding factors Raises: - ValueError: if connected_var_names is an empty list + ValueError: if connected_variable_names is an empty list """ variable_group: Union[CompositeVariableGroup, VariableGroup] @@ -640,9 +640,9 @@ class EnumerationFactorGroup(FactorGroup): uniform 0 unless the inheriting class includes a log_potentials argument. Args: - connected_var_names: A list of list of tuples, where each innermost tuple contains a - name into variable_group. Each list within the outer list is taken to contain the names of variables - neighboring a particular factor to be added. + connected_variable_names: A list of list of variable names, where each innermost element is the + name of a variable in variable_group. Each list within the outer list is taken to contain + the names of the variables connected to a factor. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations log_potentials: Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). @@ -651,19 +651,19 @@ class EnumerationFactorGroup(FactorGroup): initialized. """ - connected_var_names: Sequence[List] + connected_variable_names: Sequence[List] factor_configs: np.ndarray log_potentials: Optional[np.ndarray] = None def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: - """Function that generates a dictionary mapping set of involved variables to factors. + """Function that generates a dictionary mapping set of connected variables to factors. Returns: - a dictionary mapping all possible set of involved variables to different factors. + a dictionary mapping all possible set of connected variables to different factors. """ - num_factors = len(self.connected_var_names) + num_factors = len(self.connected_variable_names) num_val_configs = self.factor_configs.shape[0] if self.log_potentials is None: log_potentials = np.zeros((num_factors, num_val_configs), dtype=float) @@ -686,14 +686,14 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.connected_var_names[ii]), + frozenset(self.connected_variable_names[ii]), nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_names[ii]]), + tuple(self.variable_group[self.connected_variable_names[ii]]), self.factor_configs, log_potentials[ii], ), ) - for ii in range(len(self.connected_var_names)) + for ii in range(len(self.connected_variable_names)) ] ) return variables_to_factors @@ -779,35 +779,35 @@ class PairwiseFactorGroup(FactorGroup): one CompositeVariableGroup. Args: - connected_var_names: A list of list of tuples, where each innermost tuple contains a + connected_variable_names: A list of list of tuples, where each innermost tuple contains a name into variable_group. Each list within the outer list is taken to contain the names of variables neighboring a particular factor to be added. log_potential_matrix: array of shape (var1.variable_size, var2.variable_size), where var1 and var2 are the 2 VariableGroups (that may refer to the same - VariableGroup) whose names are present in each sub-list from self.connected_var_names. + VariableGroup) whose names are present in each sub-list from self.connected_variable_names. """ - connected_var_names: Sequence[List] + connected_variable_names: Sequence[List] log_potential_matrix: Optional[np.ndarray] = None def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, nodes.EnumerationFactor]: - """Function that generates a dictionary mapping set of involved variables to factors. + """Function that generates a dictionary mapping set of connected variables to factors. Returns: - a dictionary mapping all possible set of involved variables to different factors. + a dictionary mapping all possible set of connected variables to different factors. Raises: - ValueError: if every sub-list within self.connected_var_names has len != 2, or if the shape of the + ValueError: if every sub-list within self.connected_variable_names has len != 2, or if the shape of the log_potential_matrix is not the same as the variable sizes for each variable referenced in - each sub-list of self.connected_var_names + each sub-list of self.connected_variable_names """ if self.log_potential_matrix is None: log_potential_matrix = np.zeros( ( - self.variable_group[self.connected_var_names[0][0]].num_states, - self.variable_group[self.connected_var_names[0][1]].num_states, + self.variable_group[self.connected_variable_names[0][0]].num_states, + self.variable_group[self.connected_variable_names[0][1]].num_states, ) ) else: @@ -821,14 +821,14 @@ def _get_variables_to_factors( ) if log_potential_matrix.ndim == 3 and log_potential_matrix.shape[0] != len( - self.connected_var_names + self.connected_variable_names ): raise ValueError( - f"Expected log_potential_matrix for {len(self.connected_var_names)} factors. " + f"Expected log_potential_matrix for {len(self.connected_variable_names)} factors. " f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) - for fac_list in self.connected_var_names: + for fac_list in self.connected_variable_names: if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" @@ -860,21 +860,21 @@ def _get_variables_to_factors( object.__setattr__(self, "log_potential_matrix", log_potential_matrix) log_potential_matrix = np.broadcast_to( log_potential_matrix, - (len(self.connected_var_names),) + log_potential_matrix.shape[-2:], + (len(self.connected_variable_names),) + log_potential_matrix.shape[-2:], ) variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.connected_var_names[ii]), + frozenset(self.connected_variable_names[ii]), nodes.EnumerationFactor( - tuple(self.variable_group[self.connected_var_names[ii]]), + tuple(self.variable_group[self.connected_variable_names[ii]]), factor_configs, log_potential_matrix[ ii, factor_configs[:, 0], factor_configs[:, 1] ], ), ) - for ii in range(len(self.connected_var_names)) + for ii in range(len(self.connected_variable_names)) ] ) return variables_to_factors diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 14b4c206..30c2e873 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -62,7 +62,7 @@ class EnumerationFactor: """An enumeration factor Args: - variables: List of involved variables + variables: List of connected variables configs: Array of shape (num_val_configs, num_variables) An array containing an explicit enumeration of all valid configurations log_potentials: Array of shape (num_val_configs,). An array containing diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index f7c3e2c9..dc195ed4 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -158,14 +158,17 @@ def test_enumeration_factor_group(): ): enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, - connected_var_names=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + connected_variable_names=[ + [(0, 0), (0, 1), (1, 1)], + [(0, 1), (1, 0), (1, 1)], + ], factor_configs=np.zeros((1, 3), dtype=int), log_potentials=np.zeros((3, 2)), ) enumeration_factor_group = groups.EnumerationFactorGroup( variable_group=variable_group, - connected_var_names=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + connected_variable_names=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], factor_configs=np.zeros((1, 3), dtype=int), ) name = [(0, 0), (1, 1)] diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 5df6bdfa..4968c8e1 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -348,14 +348,14 @@ def create_valid_suppression_config_arr(suppression_diameter): # Add the suppression factors to the graph via kwargs fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_names={ + connected_variable_names={ idx: names for idx, names in enumerate(vert_suppression_names) }, factor_configs=valid_configs_supp, ) fg.add_factor_group( factory=groups.EnumerationFactorGroup, - connected_var_names=horz_suppression_names, + connected_variable_names=horz_suppression_names, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) @@ -407,7 +407,9 @@ def binary_connected_variables( for k_col in range(3): fg.add_factor_group( factory=groups.PairwiseFactorGroup, - connected_var_names=binary_connected_variables(28, 28, k_row, k_col), + connected_variable_names=binary_connected_variables( + 28, 28, k_row, k_col + ), log_potential_matrix=W_pot[:, :, k_row, k_col], name=(k_row, k_col), ) From 8833e2f8a002075846e42992833525c202274b91 Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 20:41:31 -0700 Subject: [PATCH 49/56] Try to avoid memory leaks --- pgmax/fg/graph.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 692464d4..e29a58ac 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -756,9 +756,10 @@ def BP(bp_state: BPState, num_iters: int): get_bp_state: Function to reconstruct the BPState from BPArrays. get_beliefs: Function to calculate beliefs from BPArrays. """ - wiring = jax.device_put(bp_state.fg_state.wiring) - max_msg_size = int(jnp.max(wiring.edges_num_states)) - num_val_configs = int(wiring.factor_configs_edge_states[-1, 0]) + 1 + max_msg_size = int(np.max(bp_state.fg_state.wiring.edges_num_states)) + num_val_configs = ( + int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1 + ) @jax.jit def run_bp( @@ -781,6 +782,7 @@ def run_bp( Returns: A BPArrays containing the updated log_potentials, ftov_msgs and evidence. """ + wiring = jax.device_put(bp_state.fg_state.wiring) log_potentials = jax.device_put(bp_state.log_potentials.value) if log_potentials_updates is not None: log_potentials = update_log_potentials( @@ -864,9 +866,10 @@ def get_beliefs(bp_arrays: BPArrays): Returns: beliefs: An array or a PyTree container containing the beliefs for the variables. """ - evidence = jax.device_put(bp_arrays.evidence) beliefs = bp_state.fg_state.variable_group.unflatten( - evidence.at[wiring.var_states_for_edges].add(bp_arrays.ftov_msgs) + jax.device_put(bp_arrays.evidence) + .at[jax.device_put(bp_state.fg_state.wiring.var_states_for_edges)] + .add(bp_arrays.ftov_msgs) ) return beliefs From e80d4c22f8d9bd40b67ba387d6a07ec12bc0c49a Mon Sep 17 00:00:00 2001 From: stannis Date: Tue, 26 Oct 2021 21:55:27 -0700 Subject: [PATCH 50/56] functools.partial --- pgmax/fg/graph.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index e29a58ac..f3f2146a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -3,6 +3,7 @@ import collections import copy +import functools import typing from dataclasses import asdict, dataclass from types import MappingProxyType @@ -319,7 +320,7 @@ def fg_state(self) -> FactorGraphState: return self.log_potentials.fg_state -@jax.partial(jax.jit, static_argnames="fg_state") +@functools.partial(jax.jit, static_argnames="fg_state") def update_log_potentials( log_potentials: jnp.ndarray, updates: Dict[Any, jnp.ndarray], @@ -449,7 +450,7 @@ def __setitem__( ) -@jax.partial(jax.jit, static_argnames="fg_state") +@functools.partial(jax.jit, static_argnames="fg_state") def update_ftov_msgs( ftov_msgs: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: @@ -614,7 +615,7 @@ def __setitem__(self, names, data) -> None: ) -@jax.partial(jax.jit, static_argnames="fg_state") +@functools.partial(jax.jit, static_argnames="fg_state") def update_evidence( evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: From 824a37c101abda62c2aa58a959327d3fef0f0ab6 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 13:06:16 -0700 Subject: [PATCH 51/56] Small fixes --- pgmax/fg/graph.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index f3f2146a..70e7a49e 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -64,6 +64,7 @@ def __post_init__(self): 0, 0, ) + # See FactorGraphState docstrings for documentation on the following fields self._num_var_states = vars_num_states_cumsum[-1] self._vars_to_starts = MappingProxyType( { @@ -137,17 +138,17 @@ def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: def _register_factor_group( self, factor_group: groups.FactorGroup, name: Optional[str] = None ) -> None: - if name in self._named_factor_groups: - raise ValueError( - f"A factor group with the name {name} already exists. Please choose a different name!" - ) - """Register a factor group to the factor graph, by updating the factor graph state. Args: factor_group: The factor group to be registered to the factor graph. name: Optional name of the factor group. """ + if name in self._named_factor_groups: + raise ValueError( + f"A factor group with the name {name} already exists. Please choose a different name!" + ) + self._factor_group_to_msgs_starts[factor_group] = self._total_factor_num_states self._factor_group_to_potentials_starts[ factor_group @@ -395,15 +396,15 @@ def __post_init__(self): object.__setattr__(self, "value", jax.device_put(self.value)) - def __getitem__(self, name: Any): + def __getitem__(self, name: Any) -> jnp.ndarray: """Function to query log potentials for a named factor group or a factor. Args: name: Name of a named factor group, or a frozenset containing the set of connected variables for the queried factor. - Returned: - The quried log potentials. + Returns: + The queried log potentials. """ if not isinstance(name, Hashable): name = frozenset(name) @@ -565,7 +566,7 @@ def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray: factor.edges_num_states[: factor.variables.index(variable)] ) msgs = jax.device_put(self.value)[start : start + variable.num_states] - return jax.device_put(msgs) + return msgs @typing.overload def __setitem__( @@ -743,7 +744,7 @@ def tree_unflatten(cls, aux_data, children): return cls(**aux_data.unflatten(children)) -def BP(bp_state: BPState, num_iters: int): +def BP(bp_state: BPState, num_iters: int) -> Tuple[Callable, Callable, Callable]: """Function for generating belief propagation functions. Args: @@ -805,7 +806,7 @@ def run_bp( ftov_msgs, wiring.edges_num_states, max_msg_size ) - def update(msgs, _): + def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]: # Compute new variable to factor messages by message passing vtof_msgs = infer.pass_var_to_fac_messages( msgs, @@ -858,8 +859,8 @@ def get_bp_state(bp_arrays: BPArrays) -> BPState: ) @jax.jit - def get_beliefs(bp_arrays: BPArrays): - """Calculate beliefs from a given BPArrays + def get_beliefs(bp_arrays: BPArrays) -> Any: + """Calculate beliefs from given BPArrays Args: bp_arrays: A BPArrays containing arrays for belief propagation. @@ -878,7 +879,7 @@ def get_beliefs(bp_arrays: BPArrays): @jax.jit -def decode_map_states(beliefs: Any): +def decode_map_states(beliefs: Any) -> Any: """Function to decode MAP states given the calculated beliefs. Args: From ba55c88eb1eed7c8898b7f087bfe28cf8dc67001 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 13:21:46 -0700 Subject: [PATCH 52/56] Add raises --- pgmax/fg/graph.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 70e7a49e..36e84109 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -143,6 +143,10 @@ def _register_factor_group( Args: factor_group: The factor group to be registered to the factor graph. name: Optional name of the factor group. + + Raises: + ValueError: If the factor group with the same name or a factor involving the same variables + already exists in the factor graph. """ if name in self._named_factor_groups: raise ValueError( @@ -302,6 +306,10 @@ class BPState: log_potentials: log potentials of the model ftov_msgs: factor to variable messages evidence: evidence (unary log potentials) for variables. + + Raises: + ValueError: If log_potentials, ftov_msgs or evidence are not derived from the same + FactorGraphState. """ log_potentials: LogPotentials @@ -336,6 +344,10 @@ def update_log_potentials( Returns: A flat jnp array containing updated log_potentials. + + Raises: ValueError if + (1) Provided log_potentials shape does not match the expected log_potentials shape. + (2) Provided name is not valid for log_potentials updates. """ for name in updates: data = updates[name] @@ -377,6 +389,9 @@ class LogPotentials: Args: fg_state: Factor graph state value: Optionally specify an initial value + + Raises: + ValueError: If provided value shape does not match the expected log_potentials shape. """ fg_state: FactorGraphState @@ -464,6 +479,10 @@ def update_ftov_msgs( Returns: A flat jnp array containing updated ftov_msgs. + + Raises: ValueError if: + (1) provided ftov_msgs shape does not match the expected ftov_msgs shape. + (2) provided name is not valid for ftov_msgs updates. """ for names in updates: data = updates[names] @@ -520,6 +539,9 @@ class FToVMessages: Args: fg_state: Factor graph state value: Optionally specify initial value for ftov messages + + Raises: ValueError if provided value does not match expected ftov messages + shape. """ fg_state: FactorGraphState @@ -549,6 +571,8 @@ def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray: Returns: An array containing the current ftov messages from factor names[0] to variable names[1] + + Raises: ValueError if provided names are not valid for querying ftov messages. """ if not ( isinstance(names, tuple) @@ -661,6 +685,8 @@ class Evidence: Args: fg_state: Factor graph state value: Optionally specify initial value for evidence + + Raises: ValueError if provided value does not match expected evidence shape. """ fg_state: FactorGraphState From d0f6ad93dd379c3466a877790ec00f89e026d5df Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 13:26:27 -0700 Subject: [PATCH 53/56] Fix mypy error --- examples/sanity_check_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index fa300df5..e4ca74b6 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -333,7 +333,7 @@ def create_valid_suppression_config_arr(suppression_diameter): bp_state = fg.bp_state bp_state.evidence["grid_vars"] = grid_evidence_arr bp_state.evidence["additional_vars"] = additional_vars_evidence_dict -run_bp, _, get_beliefs = graph.BP(bp_state, 1000) +run_bp, get_bp_state, get_beliefs = graph.BP(bp_state, 1000) bp_start_time = timer() bp_arrays = run_bp() bp_end_time = timer() From b454b9511a25e6631da5d45dc831a280c9881634 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 20:27:14 -0700 Subject: [PATCH 54/56] Raises --- pgmax/fg/groups.py | 58 ++++++++++++++++++++++++++++++++++++++++++---- pgmax/fg/nodes.py | 13 ++++------- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 509fc23c..38658b0d 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -64,6 +64,9 @@ def __getitem__(self, name): Returns: A single variable if the name is not a list. A list of variables if name is a list + + Raises: + ValueError: if the name is not found in the group """ if isinstance(name, List): @@ -196,6 +199,9 @@ def __getitem__(self, name): Returns: A single variable if the name is not a list. A list of variables if name is a list + + Raises: + ValueError: if the name does not have the right format (tuples with at least two elements). """ if isinstance(name, List): names_list = name @@ -277,6 +283,11 @@ def unflatten( Returns: Meaningful structured data, with structure matching that of self.variable_group_container. + + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape """ if flat_data.ndim != 1: raise ValueError( @@ -378,6 +389,9 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Returns: A flat jnp.array for internal use + + Raises: + ValueError: If the data is not of the correct shape. """ if data.shape != self.shape and data.shape != self.shape + ( self.variable_size, @@ -400,6 +414,11 @@ def unflatten( Returns: Meaningful structured data. An array of shape self.shape (for e.g. MAP decodings) or an array of shape self.shape + (self.variable_size,) (for e.g. evidence, beliefs). + + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape """ if flat_data.ndim != 1: raise ValueError( @@ -458,6 +477,11 @@ def flatten( Returns: A flat jnp.array for internal use + + Raises: + ValueError if: + (1) data is referring to a non-existing variable + (2) data is not of the correct shape """ for name in data: if name not in self._names_to_variables: @@ -487,6 +511,10 @@ def unflatten( Each value should be an array of shape (1,) (for e.g. MAP decodings) or (self.variable_size,) (for e.g. evidence, beliefs). + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape """ if flat_data.ndim != 1: raise ValueError( @@ -560,6 +588,9 @@ def __getitem__( Returns: A queried individual factor + + Raises: + ValueError: if the queried factor is not present in the factor group """ variables = frozenset(variables) if variables not in self._variables_to_factors: @@ -588,7 +619,7 @@ def _get_variables_to_factors( a dictionary mapping all possible names to different factors. """ raise NotImplementedError( - "Please subclass the VariableGroup class and override this method" + "Please subclass the FactorGroup class and override this method" ) @cached_property @@ -662,6 +693,9 @@ def _get_variables_to_factors( Returns: a dictionary mapping all possible set of connected variables to different factors. + + Raises: + ValueError: if the specified log_potentials is not of the right shape """ num_factors = len(self.connected_variable_names) num_val_configs = self.factor_configs.shape[0] @@ -708,6 +742,9 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Returns: A flat jnp.array for internal use + + Raises: + ValueError: if data is not of the right shape. """ num_factors = len(self.factors) if ( @@ -744,6 +781,11 @@ def unflatten( Meaningful structured data. Should be an array of shape (num_val_configs,) (for shared log potentials) or (num_factors, num_val_configs) (for log potentials) or (num_factors, num_edge_states) (for ftov messages). + + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape """ if flat_data.ndim != 1: raise ValueError( @@ -799,9 +841,12 @@ def _get_variables_to_factors( a dictionary mapping all possible set of connected variables to different factors. Raises: - ValueError: if every sub-list within self.connected_variable_names has len != 2, or if the shape of the - log_potential_matrix is not the same as the variable sizes for each variable referenced in - each sub-list of self.connected_variable_names + ValueError if: + (1) The specified log_potential_matrix is not a 2D or 3D array. + (2) Some pairwise factors connect to less or more than 2 variables. + (3) The specified log_potential_matrix does not match the number of factors. + (4) The specified log_potential_matrix does not match the number of variable states of the + variables in the factors. """ if self.log_potential_matrix is None: log_potential_matrix = np.zeros( @@ -925,6 +970,11 @@ def unflatten( (num_factors, var0_num_states, var1_num_states) (for log potential matrices) or (num_factors, var0_num_states + var1_num_states) (for ftov messages) or (var0_num_states, var1_num_states) (for shared log potential matrix). + + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape """ if flat_data.ndim != 1: raise ValueError( diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 30c2e873..dfa0bb50 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -70,14 +70,11 @@ class EnumerationFactor: Raises: ValueError: If: - (1) the dtype of the configs array is not int - (2) the dtype of the potential array is not float - (3) configs array doesn't have the same number of columns - as there are variables - (4) the potential array doesn't have the same number of - rows as the configs array - (5) any value in the configs array is greater than the size - of the corresponding variable or less than 0. + (1) The dtype of the configs array is not int + (2) The dtype of the potential array is not float + (3) Configs does not have the correct shape + (4) The potential array does not have the correct shape + (5) The configs array contains invalid values """ variables: Tuple[Variable, ...] From 7c4690a381da78cdfbc54b2004c223cae4845c00 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 20:32:07 -0700 Subject: [PATCH 55/56] Comments --- pgmax/fg/groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 38658b0d..4265ebd5 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -595,7 +595,7 @@ def __getitem__( variables = frozenset(variables) if variables not in self._variables_to_factors: raise ValueError( - f"The queried factor {variables} is not present in the factor group." + f"The queried factor connected to the set of variables {variables} is not present in the factor group." ) return self._variables_to_factors[variables] From 32d9c48356a4a86a30ae07bf6cc73d5da592ffd0 Mon Sep 17 00:00:00 2001 From: stannis Date: Sun, 31 Oct 2021 20:35:23 -0700 Subject: [PATCH 56/56] Fix tests --- tests/fg/test_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index dc195ed4..cd7410a5 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -175,7 +175,7 @@ def test_enumeration_factor_group(): with pytest.raises( ValueError, match=re.escape( - f"The queried factor {frozenset(name)} is not present in the factor group." + f"The queried factor connected to the set of variables {frozenset(name)} is not present in the factor group." ), ): enumeration_factor_group[name]