diff --git a/tensorflow_probability/python/experimental/bijectors/highway_flow.py b/tensorflow_probability/python/experimental/bijectors/highway_flow.py
index 4c25aee1fb..c74f4bcf58 100644
--- a/tensorflow_probability/python/experimental/bijectors/highway_flow.py
+++ b/tensorflow_probability/python/experimental/bijectors/highway_flow.py
@@ -22,7 +22,6 @@
 from tensorflow_probability.python.internal import cache_util
 from tensorflow_probability.python.internal import samplers
-
 def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
                              activation_fn=False,
@@ -363,3 +362,243 @@ def _inverse_log_det_jacobian(self, y):
     _, attrs = self._augmented_inverse(y)
     cached.update(attrs)
     return cached['ildj']
+
+
+def build_highway_flow_layer(width, residual_fraction_initial_value=0.5,
+                             activation_fn=False, seed=None):
+  # TODO: validate that `residual_fraction_initial_value` is between 0 and 1.
+  residual_fraction_initial_value = tf.convert_to_tensor(
+      residual_fraction_initial_value,
+      dtype_hint=tf.float32,
+      name='residual_fraction_initial_value')
+  dtype = residual_fraction_initial_value.dtype
+
+  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
+      seed, n=4)
+  return HighwayFlow(
+      residual_fraction=util.TransformedVariable(
+          initial_value=residual_fraction_initial_value,
+          bijector=tfb.Sigmoid(),
+          dtype=dtype),
+      activation_fn=activation_fn,
+      bias=tf.Variable(
+          samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
+          dtype=dtype),
+      upper_diagonal_weights_matrix=util.TransformedVariable(
+          initial_value=tf.experimental.numpy.tril(
+              samplers.normal((width, width), mean=0., stddev=1.,
+                              seed=upper_seed),
+              k=-1) + tf.linalg.diag(
+                  samplers.uniform((width,), minval=0., maxval=1.,
+                                   seed=diagonal_seed)),
+          bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
+                                     diag_shift=None),
+          dtype=dtype),
+      lower_diagonal_weights_matrix=util.TransformedVariable(
+          initial_value=samplers.normal((width, width), mean=0., stddev=1.,
+                                        seed=lower_seed),
+          bijector=tfb.Chain(
+              [tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
+               tfb.Pad(paddings=[(1, 0), (0, 1)]),
+               tfb.FillTriangular()]),
+          dtype=dtype))
+
+
+class HighwayFlow(tfb.Bijector):
+  """Implements a Highway Flow bijector [1], which interpolates the input
+  `X` with the transformation applied at each step of the bijector.
+  The Highway Flow can be used as a building block for a Cascading Flow [1]
+  or as a generic normalizing flow.
+
+  The transformation consists of a convex update between the input `X` and a
+  linear transformation of `X` followed by activation, of the form
+  `g(A @ X + b)`, where `g(.)` is a differentiable non-decreasing activation
+  function and `A` and `b` are trainable weights.
+
+  The convex update is regulated by a trainable residual fraction `l`
+  constrained between 0 and 1, and can be formalized as:
+  `Y = l * X + (1 - l) * g(A @ X + b)`.
+
+  To make this transformation invertible, the bijector is split into three
+  convex updates:
+  - `Y1 = l * X + (1 - l) * L @ X`, with `L` a lower-triangular matrix with
+    ones on the diagonal;
+  - `Y2 = l * Y1 + (1 - l) * (U @ Y1 + b)`, with `U` an upper-triangular
+    matrix with a positive diagonal;
+  - `Y = l * Y2 + (1 - l) * g(Y2)`.
+
+  The function `build_highway_flow_layer` helps initialize the bijector
+  with variables that respect the various constraints.
+
+  For more details on Highway Flow and Cascading Flows see [1].
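+
+  As a quick sanity check (a minimal sketch, assuming the default
+  `activation_fn=False` configuration), the forward/inverse pair should
+  round-trip and produce matching log-det-Jacobians:
+
+  ```python
+  bijector = build_highway_flow_layer(4)
+  x = tf.ones(4)
+  y = bijector.forward(x)
+  tf.debugging.assert_near(bijector.inverse(y), x, atol=1e-4)
+  tf.debugging.assert_near(
+      bijector.forward_log_det_jacobian(x, event_ndims=1),
+      -bijector.inverse_log_det_jacobian(y, event_ndims=1))
+  ```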
+
+  #### Usage example:
+
+  ```python
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
+
+  dim = 4  # last input dimension
+
+  bijector = build_highway_flow_layer(dim, activation_fn=True)
+  x = tf.ones(dim)  # any input whose last dimension is `dim`
+  y = bijector.forward(x)  # forward mapping
+  x = bijector.inverse(y)  # inverse mapping
+  base = tfd.MultivariateNormalDiag(loc=tf.zeros(dim))  # Base distribution
+  transformed_distribution = tfd.TransformedDistribution(base, bijector)
+  ```
+
+  #### References
+
+  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven.
+    "Automatic variational inference with cascading flows." arXiv preprint
+    arXiv:2102.04801 (2021).
+  """
+
+  # HighwayFlow simultaneously computes `forward` and `fldj`
+  # (and `inverse`/`ildj`), so we override the bijector cache to update the
+  # LDJ entries of attrs on forward/inverse calls (instead of updating them
+  # only when the LDJ methods themselves are called).
+  _cache = cache_util.BijectorCacheWithGreedyAttrs(
+      forward_name='_augmented_forward',
+      inverse_name='_augmented_inverse')
+
+  def __init__(self, residual_fraction, activation_fn, bias,
+               upper_diagonal_weights_matrix,
+               lower_diagonal_weights_matrix, validate_args=False,
+               name='highway_flow'):
+    """Initializes the HighwayFlow bijector.
+
+    Args:
+      residual_fraction: scalar `Tensor` used for the convex update; must
+        be between 0 and 1.
+      activation_fn: Python `bool` deciding whether to apply a softplus
+        activation (`True`) or no activation (`False`).
+      bias: bias vector.
+      upper_diagonal_weights_matrix: lower-triangular matrix of size
+        (width, width) with positive diagonal entries (transposed to
+        upper-triangular within the bijector).
+      lower_diagonal_weights_matrix: lower-triangular matrix with ones on
+        the main diagonal.
+      validate_args: Python `bool` indicating whether arguments should be
+        checked for correctness.
+      name: Python `str` name for ops created by this object.
+    """
+    parameters = dict(locals())
+    with tf.name_scope(name) as name:
+      self._width = tf.shape(bias)[-1]
+      self._bias = bias
+      self._residual_fraction = residual_fraction
+      # The upper matrix is stored as lower-triangular; the transpose is
+      # applied in the _forward and _inverse methods, within matvec.
+      self._upper_diagonal_weights_matrix = upper_diagonal_weights_matrix
+      self._lower_diagonal_weights_matrix = lower_diagonal_weights_matrix
+      self._activation_fn = activation_fn
+
+      super(HighwayFlow, self).__init__(
+          validate_args=validate_args,
+          forward_min_event_ndims=1,
+          parameters=parameters,
+          name=name)
+
+  @property
+  def bias(self):
+    return self._bias
+
+  @property
+  def width(self):
+    return self._width
+
+  @property
+  def residual_fraction(self):
+    return self._residual_fraction
+
+  @property
+  def upper_diagonal_weights_matrix(self):
+    return self._upper_diagonal_weights_matrix
+
+  @property
+  def lower_diagonal_weights_matrix(self):
+    return self._lower_diagonal_weights_matrix
+
+  @property
+  def activation_fn(self):
+    return self._activation_fn
+
+  def _derivative_of_sigmoid(self, x):
+    # Derivative of the convex-updated softplus activation; note that
+    # softplus'(x) = sigmoid(x).
+    return self.residual_fraction + (
+        1. - self.residual_fraction) * tf.math.sigmoid(x)
+
+  def _convex_update(self, weights_matrix):
+    return self.residual_fraction * tf.eye(self.width) + (
+        1. - self.residual_fraction) * weights_matrix
+
+  def _inverse_of_sigmoid(self, y, n=20):
+    # Inverse of the convex-updated softplus activation, computed with `n`
+    # steps of Newton's method.
+    x = tf.ones_like(y)
+    for _ in range(n):
+      x = x - (self.residual_fraction * x +
+               (1. - self.residual_fraction) * tf.math.softplus(x) - y) / (
+                   self._derivative_of_sigmoid(x))
+    return x
+
+  def _augmented_forward(self, x):
+    # Log-determinant term from the upper matrix. Note that the
+    # log-determinant of the lower matrix is zero.
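+    # For a convex update y = l * x + (1 - l) * M @ x, the Jacobian is
+    # l * I + (1 - l) * M. With M = L (unit diagonal), each diagonal entry
+    # is l + (1 - l) = 1 and the log-determinant vanishes; with M = U
+    # (diagonal entries u_i > 0) it is sum_i log(l + (1 - l) * u_i), which
+    # is the sum accumulated below.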
+    fldj = tf.zeros(x.shape[:-1]) + tf.reduce_sum(
+        tf.math.log(self.residual_fraction + (
+            1. - self.residual_fraction) * tf.linalg.diag_part(
+                self.upper_diagonal_weights_matrix)))
+    x = tf.linalg.matvec(
+        self._convex_update(self.lower_diagonal_weights_matrix), x)
+    x = tf.linalg.matvec(
+        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
+        x) + (1. - self.residual_fraction) * self.bias
+    if self.activation_fn:
+      fldj += tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(x)), -1)
+      # `activation_fn` is a bool flag, so apply the softplus activation
+      # directly rather than calling the flag.
+      x = self.residual_fraction * x + (
+          1. - self.residual_fraction) * tf.math.softplus(x)
+    return x, {'ildj': -fldj, 'fldj': fldj}
+
+  def _augmented_inverse(self, y):
+    ildj = tf.zeros(y.shape[:-1]) - tf.reduce_sum(
+        tf.math.log(self.residual_fraction + (
+            1. - self.residual_fraction) * tf.linalg.diag_part(
+                self.upper_diagonal_weights_matrix)))
+    if self.activation_fn:
+      y = self._inverse_of_sigmoid(y)
+      ildj -= tf.reduce_sum(tf.math.log(self._derivative_of_sigmoid(y)), -1)
+
+    y = tf.linalg.triangular_solve(
+        tf.transpose(self._convex_update(self.upper_diagonal_weights_matrix)),
+        tf.linalg.matrix_transpose(
+            y - (1. - self.residual_fraction) * self.bias),
+        lower=False)
+    y = tf.linalg.triangular_solve(
+        self._convex_update(self.lower_diagonal_weights_matrix), y)
+    return tf.linalg.matrix_transpose(y), {'ildj': ildj, 'fldj': -ildj}
+
+  def _forward(self, x):
+    y, _ = self._augmented_forward(x)
+    return y
+
+  def _inverse(self, y):
+    x, _ = self._augmented_inverse(y)
+    return x
+
+  def _forward_log_det_jacobian(self, x):
+    cached = self._cache.forward_attributes(x)
+    # If LDJ isn't in the cache, call forward once.
+    if 'fldj' not in cached:
+      _, attrs = self._augmented_forward(x)
+      cached.update(attrs)
+    return cached['fldj']
+
+  def _inverse_log_det_jacobian(self, y):
+    cached = self._cache.inverse_attributes(y)
+    # If LDJ isn't in the cache, call inverse once.
+    if 'ildj' not in cached:
+      _, attrs = self._augmented_inverse(y)
+      cached.update(attrs)
+    return cached['ildj']
\ No newline at end of file
diff --git a/tensorflow_probability/python/experimental/vi/cascading_flows.py b/tensorflow_probability/python/experimental/vi/cascading_flows.py
deleted file mode 100644
index 61dcce7236..0000000000
--- a/tensorflow_probability/python/experimental/vi/cascading_flows.py
+++ /dev/null
@@ -1,542 +0,0 @@
-# Copyright 2021 The TensorFlow Probability Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================ -"""Utilities for constructing structured surrogate posteriors.""" - -from __future__ import absolute_import -from __future__ import division -# [internal] enable type annotations -from __future__ import print_function - -import copy -import functools -import inspect - -import tensorflow.compat.v2 as tf - -from tensorflow_probability.python.bijectors import chain -from tensorflow_probability.python.bijectors import reshape -from tensorflow_probability.python.bijectors import scale as scale_lib -from tensorflow_probability.python.bijectors import shift -from tensorflow_probability.python.bijectors import split -from tensorflow_probability.python.distributions import batch_broadcast -from tensorflow_probability.python.distributions import beta -from tensorflow_probability.python.distributions import blockwise -from tensorflow_probability.python.distributions import chi2 -from tensorflow_probability.python.distributions import deterministic -from tensorflow_probability.python.distributions import exponential -from tensorflow_probability.python.distributions import gamma -from tensorflow_probability.python.distributions import half_normal -from tensorflow_probability.python.distributions import independent -from tensorflow_probability.python.distributions import \ - joint_distribution_auto_batched -from tensorflow_probability.python.distributions import \ - joint_distribution_coroutine -from tensorflow_probability.python.distributions import normal -from tensorflow_probability.python.distributions import sample -from tensorflow_probability.python.distributions import transformed_distribution -from tensorflow_probability.python.distributions import truncated_normal -from tensorflow_probability.python.distributions import uniform -from tensorflow_probability.python.experimental.bijectors import \ - build_highway_flow_layer -from tensorflow_probability.python.internal import samplers - -__all__ = [ - 'register_cf_substitution_rule', - 'build_cf_surrogate_posterior' -] - -Root = joint_distribution_coroutine.JointDistributionCoroutine.Root - -_NON_STATISTICAL_PARAMS = [ - 'name', 'validate_args', 'allow_nan_stats', 'experimental_use_kahan_sum', - 'reinterpreted_batch_ndims', 'dtype', 'force_probs_to_zero_outside_support', - 'num_probit_terms_approx' -] -_NON_TRAINABLE_PARAMS = ['low', 'high'] - -# Registry of transformations that are applied to distributions in the prior -# before defining the surrogate family. - - -# Todo: inherited from asvi code, do we need this? -ASVI_SURROGATE_SUBSTITUTIONS = {} - - -# Todo: inherited from asvi code, do we need this? -def _as_substituted_distribution(distribution): - """Applies all substitution rules that match a distribution.""" - for condition, substitution_fn in ASVI_SURROGATE_SUBSTITUTIONS.items(): - if condition(distribution): - distribution = substitution_fn(distribution) - return distribution - - -# Todo: inherited from asvi code, do we need this? -def register_cf_substitution_rule(condition, substitution_fn): - """Registers a rule for substituting distributions in ASVI surrogates. - - Args: - condition: Python `callable` that takes a Distribution instance and - returns a Python `bool` indicating whether or not to substitute it. - May also be a class type such as `tfd.Normal`, in which case the - condition is interpreted as - `lambda distribution: isinstance(distribution, class)`. 
- substitution_fn: Python `callable` that takes a Distribution - instance and returns a new Distribution instance used to define - the ASVI surrogate posterior. Note that this substitution does not modify - the original model. - - #### Example - - To use a Normal surrogate for all location-scale family distributions, we - could register the substitution: - - ```python - tfp.experimental.vi.register_asvi_surrogate_substitution( - condition=lambda distribution: ( - hasattr(distribution, 'loc') and hasattr(distribution, 'scale')) - substitution_fn=lambda distribution: ( - # Invoking the event space bijector applies any relevant constraints, - # e.g., that HalfCauchy samples must be `>= loc`. - distribution.experimental_default_event_space_bijector()( - tfd.Normal(loc=distribution.loc, scale=distribution.scale))) - ``` - - This rule will fire when ASVI encounters a location-scale distribution, - and instructs ASVI to build a surrogate 'as if' the model had just used a - (possibly constrained) Normal in its place. Note that we could have used a - more precise condition, e.g., to limit the substitution to distributions with - a specific `name`, if we had reason to think that a Normal distribution would - be a good surrogate for some model variables but not others. - - """ - global ASVI_SURROGATE_SUBSTITUTIONS - if inspect.isclass(condition): - condition = lambda distribution, cls=condition: isinstance( - # pylint: disable=g-long-lambda - distribution, cls) - ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn - - -# Default substitutions attempt to express distributions using the most -# flexible available parameterization. -# pylint: disable=g-long-lambda -register_cf_substitution_rule( - half_normal.HalfNormal, - lambda dist: truncated_normal.TruncatedNormal( - loc=0., scale=dist.scale, low=0., high=dist.scale * 10.)) -register_cf_substitution_rule( - uniform.Uniform, - lambda dist: shift.Shift(dist.low)( - scale_lib.Scale(dist.high - dist.low)( - beta.Beta(concentration0=tf.ones_like(dist.mean()), - concentration1=1.)))) -register_cf_substitution_rule( - exponential.Exponential, - lambda dist: gamma.Gamma(concentration=1., rate=dist.rate)) -register_cf_substitution_rule( - chi2.Chi2, - lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5)) - - -# pylint: enable=g-long-lambda - -# a single JointDistribution. -def build_cf_surrogate_posterior( - prior, - num_auxiliary_variables=0, - initial_prior_weight=0.5, - seed=None, - name=None): - # todo: change docstrings - """Builds a structured surrogate posterior inspired by conjugate updating. - - ASVI, or Automatic Structured Variational Inference, was proposed by - Ambrogioni et al. (2020) [1] as a method of automatically constructing a - surrogate posterior with the same structure as the prior. It does this by - reparameterizing the variational family of the surrogate posterior by - structuring each parameter according to the equation - ```none - prior_weight * prior_parameter + (1 - prior_weight) * mean_field_parameter - ``` - In this equation, `prior_parameter` is a vector of prior parameters and - `mean_field_parameter` is a vector of trainable parameters with the same - domain as `prior_parameter`. `prior_weight` is a vector of learnable - parameters where `0. <= prior_weight <= 1.`. When `prior_weight = - 0`, the surrogate posterior will be a mean-field surrogate, and when - `prior_weight = 1.`, the surrogate posterior will be the prior. 
This convex - combination equation, inspired by conjugacy in exponential families, thus - allows the surrogate posterior to balance between the structure of the prior - and the structure of a mean-field approximation. - - Args: - prior: tfd.JointDistribution instance of the prior. - mean_field: Optional Python boolean. If `True`, creates a degenerate - surrogate distribution in which all variables are independent, - ignoring the prior dependence structure. Default value: `False`. - initial_prior_weight: Optional float value (either static or tensor value) - on the interval [0, 1]. A larger value creates an initial surrogate - distribution with more dependence on the prior structure. Default value: - `0.5`. - seed: Python `int` seed for random initialization. - name: Optional string. Default value: `build_cf_surrogate_posterior`. - - Returns: - surrogate_posterior: A `tfd.JointDistributionCoroutineAutoBatched` instance - whose samples have shape and structure matching that of `prior`. - - Raises: - TypeError: The `prior` argument cannot be a nested `JointDistribution`. - - ### Examples - - Consider a Brownian motion model expressed as a JointDistribution: - - ```python - prior_loc = 0. - innovation_noise = .1 - - def model_fn(): - new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise) - for i in range(4): - new = yield tfd.Normal(loc=new, scale=innovation_noise) - - prior = tfd.JointDistributionCoroutineAutoBatched(model_fn) - ``` - - Let's use variational inference to approximate the posterior. We'll build a - surrogate posterior distribution by feeding in the prior distribution. - - ```python - surrogate_posterior = - tfp.experimental.vi.build_cf_surrogate_posterior(prior) - ``` - - This creates a trainable joint distribution, defined by variables in - `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior` - to fit this distribution by minimizing a divergence to the true posterior. - - ```python - losses = tfp.vi.fit_surrogate_posterior( - target_log_prob_fn, - surrogate_posterior=surrogate_posterior, - num_steps=100, - optimizer=tf.optimizers.Adam(0.1), - sample_size=10) - - # After optimization, samples from the surrogate will approximate - # samples from the true posterior. - samples = surrogate_posterior.sample(100) - posterior_mean = [tf.reduce_mean(x) for x in samples] - posterior_std = [tf.math.reduce_std(x) for x in samples] - ``` - - #### References - [1]: Luca Ambrogioni, Max Hinne, Marcel van Gerven. Automatic structured - variational inference. _arXiv preprint arXiv:2002.00643_, 2020 - https://arxiv.org/abs/2002.00643 - - """ - with tf.name_scope(name or 'build_cf_surrogate_posterior'): - surrogate_posterior, variables = _cf_surrogate_for_distribution( - dist=prior, - base_distribution_surrogate_fn=functools.partial( - _cf_convex_update_for_base_distribution, - initial_prior_weight=initial_prior_weight, - num_auxiliary_variables=num_auxiliary_variables), - num_auxiliary_variables=num_auxiliary_variables, - seed=seed) - surrogate_posterior.also_track = variables - return surrogate_posterior - - -def _cf_surrogate_for_distribution(dist, - base_distribution_surrogate_fn, - sample_shape=None, - variables=None, - num_auxiliary_variables=0, - global_auxiliary_variables=None, - seed=None): - # todo: change docstrings - """Recursively creates ASVI surrogates, and creates new variables if needed. - - Args: - dist: a `tfd.Distribution` instance. 
- base_distribution_surrogate_fn: Callable to build a surrogate posterior - for a 'base' (non-meta and non-joint) distribution, with signature - `surrogate_posterior, variables = base_distribution_fn( - dist, sample_shape=None, variables=None, seed=None)`. - sample_shape: Optional `Tensor` shape of samples drawn from `dist` by - `tfd.Sample` wrappers. If not `None`, the surrogate's event will include - independent sample dimensions, i.e., it will have event shape - `concat([sample_shape, dist.event_shape], axis=0)`. - Default value: `None`. - variables: Optional nested structure of `tf.Variable`s returned from a - previous call to `_cf_surrogate_for_distribution`. If `None`, - new variables will be created; otherwise, constructs a surrogate posterior - backed by the passed-in variables. - Default value: `None`. - seed: Python `int` seed for random initialization. - Returns: - surrogate_posterior: Instance of `tfd.Distribution` representing a trainable - surrogate posterior distribution, with the same structure and `name` as - `dist`. - variables: Nested structure of `tf.Variable` trainable parameters for the - surrogate posterior. If `dist` is a base distribution, this is - a `dict` of `ASVIParameters` instances. If `dist` is a joint - distribution, this is a `dist.dtype` structure of such `dict`s. - """ - - # Apply any substitutions, while attempting to preserve the original name. - dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist)) - - if hasattr(dist, '_model_coroutine'): - surrogate_posterior, variables = _cf_surrogate_for_joint_distribution( - dist, - base_distribution_surrogate_fn=base_distribution_surrogate_fn, - variables=variables, - num_auxiliary_variables=num_auxiliary_variables, - global_auxiliary_variables=global_auxiliary_variables, - seed=seed) - else: - surrogate_posterior, variables = base_distribution_surrogate_fn( - dist=dist, sample_shape=sample_shape, variables=variables, - global_auxiliary_variables=global_auxiliary_variables, seed=seed) - return surrogate_posterior, variables - - -def _cf_surrogate_for_joint_distribution( - dist, base_distribution_surrogate_fn, variables=None, - num_auxiliary_variables=0, global_auxiliary_variables=None, seed=None): - """Builds a structured joint surrogate posterior for a joint model.""" - - # Probabilistic program for CF surrogate posterior. 
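-  # The generator below mirrors the prior's coroutine one distribution at a
-  # time: for each distribution yielded by the prior, it yields a surrogate
-  # built from new or passed-in variables, so that sampling the surrogate
-  # traverses the same dependency structure as sampling the prior.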
- flat_variables = dist._model_flatten( - variables) if variables else None # pylint: disable=protected-access - prior_coroutine = dist._model_coroutine # pylint: disable=protected-access - - def posterior_generator(seed=seed): - prior_gen = prior_coroutine() - dist = next(prior_gen) - - if num_auxiliary_variables > 0: - i = 1 - - if flat_variables: - variables = flat_variables[0] - - else: - layers = 3 - bijectors = [] - - for _ in range(0, layers - 1): - bijectors.append( - build_highway_flow_layer(num_auxiliary_variables, - residual_fraction_initial_value=0.5, - activation_fn=True, gate_first_n=0, - seed=seed)) - bijectors.append( - build_highway_flow_layer(num_auxiliary_variables, - residual_fraction_initial_value=0.5, - activation_fn=False, gate_first_n=0, - seed=seed)) - - variables = chain.Chain(bijectors=list(reversed(bijectors))) - - eps = transformed_distribution.TransformedDistribution( - distribution=sample.Sample(normal.Normal(0., 0.1), - num_auxiliary_variables), - bijector=variables) - - eps = Root(eps) - - value_out = yield (eps if flat_variables - else (eps, variables)) - - global_auxiliary_variables = value_out - - else: - i = 0 - - try: - while True: - was_root = isinstance(dist, Root) - if was_root: - dist = dist.distribution - - seed, init_seed = samplers.split_seed(seed) - surrogate_posterior, variables = _cf_surrogate_for_distribution( - dist, - base_distribution_surrogate_fn=base_distribution_surrogate_fn, - variables=flat_variables[i] if flat_variables else None, - global_auxiliary_variables=global_auxiliary_variables, - seed=init_seed) - - if was_root and num_auxiliary_variables == 0: - surrogate_posterior = Root(surrogate_posterior) - # If variables were not given---i.e., we're creating new - # variables---then yield the new variables along with the surrogate - # posterior. This assumes an execution context such as - # `_extract_variables_from_coroutine_model` below that will capture and - # save the variables. - value_out = yield (surrogate_posterior if flat_variables - else (surrogate_posterior, variables)) - if type(value_out) == list: - if len(dist.event_shape) == 0: - dist = prior_gen.send(tf.squeeze(value_out[0], -1)) - else: - dist = prior_gen.send(value_out[0]) - - else: - dist = prior_gen.send(value_out) - i += 1 - except StopIteration: - pass - - if variables is None: - # Run the generator to create variables, then call ourselves again - # to construct the surrogate JD from these variables. Note that we can't - # just create a JDC from the current `posterior_generator`, because it will - # try to build new variables on every invocation; the recursive call will - # define a new `posterior_generator` that knows about the variables we're - # about to create. - return _cf_surrogate_for_joint_distribution( - dist=dist, - base_distribution_surrogate_fn=base_distribution_surrogate_fn, - num_auxiliary_variables=num_auxiliary_variables, - global_auxiliary_variables=global_auxiliary_variables, - variables=dist._model_unflatten( # pylint: disable=protected-access - _extract_variables_from_coroutine_model( - posterior_generator, seed=seed))) - - # Temporary workaround for bijector caching issues with autobatched JDs. 
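-  # Use the autobatched coroutine JD when the prior is autobatched (detected
-  # via `use_vectorized_map` below); otherwise fall back to the plain
-  # JointDistributionCoroutine.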
- surrogate_type = joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched - if not hasattr(dist, 'use_vectorized_map'): - surrogate_type = joint_distribution_coroutine.JointDistributionCoroutine - surrogate_posterior = surrogate_type(posterior_generator, - name=_get_name(dist)) - - # Ensure that the surrogate posterior structure matches that of the prior. - # todo: check me, do we need this? in case needs to be modified - # if we use auxiliary variables, then the structure won't match the one of the - # prior - '''try: - tf.nest.assert_same_structure(dist.dtype, surrogate_posterior.dtype) - except TypeError: - tokenize = lambda jd: jd._model_unflatten( - # pylint: disable=protected-access, g-long-lambda - range(len(jd._model_flatten(jd.dtype))) - # pylint: disable=protected-access - ) - surrogate_posterior = restructure.Restructure( - output_structure=tokenize(dist), - input_structure=tokenize(surrogate_posterior))( - surrogate_posterior, name=_get_name(dist))''' - return surrogate_posterior, variables - - -# todo: sample_shape and seed are not used.. maybe they should? -def _cf_convex_update_for_base_distribution(dist, - initial_prior_weight, - num_auxiliary_variables=0, - global_auxiliary_variables=None, - sample_shape=None, - variables=None, - seed=None): - """Creates a trainable surrogate for a (non-meta, non-joint) distribution.""" - - if variables is None: - actual_event_shape = dist.event_shape_tensor() - int_event_shape = int(actual_event_shape) if \ - actual_event_shape.shape.as_list()[0] > 0 else 1 - layers = 3 - bijectors = [reshape.Reshape([-1], - event_shape_in=actual_event_shape + - num_auxiliary_variables)] - - for _ in range(0, layers - 1): - bijectors.append( - build_highway_flow_layer( - tf.reduce_prod(actual_event_shape + num_auxiliary_variables), - residual_fraction_initial_value=initial_prior_weight, - activation_fn=True, gate_first_n=int_event_shape, seed=seed)) - bijectors.append( - build_highway_flow_layer( - tf.reduce_prod(actual_event_shape + num_auxiliary_variables), - residual_fraction_initial_value=initial_prior_weight, - activation_fn=False, gate_first_n=int_event_shape, seed=seed)) - bijectors.append( - reshape.Reshape(actual_event_shape + num_auxiliary_variables)) - - variables = chain.Chain(bijectors=list(reversed(bijectors))) - - if num_auxiliary_variables > 0: - batch_shape = global_auxiliary_variables.shape[0] if len( - global_auxiliary_variables.shape) > 1 else [] - - cascading_flows = split.Split( - [-1, num_auxiliary_variables])( - transformed_distribution.TransformedDistribution( - distribution=blockwise.Blockwise([ - batch_broadcast.BatchBroadcast(dist, - to_shape=batch_shape), - independent.Independent( - deterministic.Deterministic( - global_auxiliary_variables), - reinterpreted_batch_ndims=1)]), - bijector=variables)) - - else: - cascading_flows = transformed_distribution.TransformedDistribution( - distribution=dist, - bijector=variables) - - return cascading_flows, variables - - -def _extract_variables_from_coroutine_model(model_fn, seed=None): - """Extracts variables from a generator that yields (dist, variables) pairs.""" - gen = model_fn() - try: - dist, dist_variables = next(gen) - flat_variables = [dist_variables] - while True: - seed, local_seed = samplers.split_seed(seed, n=2) - sampled_value = (dist.distribution.sample(seed=local_seed) - if isinstance(dist, Root) - else dist.sample(seed=local_seed)) - dist, dist_variables = gen.send( - sampled_value) # tf.concat(sampled_value, axis=0) - 
flat_variables.append(dist_variables) - except StopIteration: - pass - return flat_variables - - -def _set_name(dist, name): - """Copies a distribution-like object, replacing its name.""" - if hasattr(dist, 'copy'): - return dist.copy(name=name) - # Some distribution-like entities such as JointDistributionPinned don't - # inherit from tfd.Distribution and don't define `self.copy`. We'll try to set - # the name directly. - dist = copy.copy(dist) - dist._name = name # pylint: disable=protected-access - return dist - - -def _get_name(dist): - """Attempts to get a distribution's short name, excluding the name scope.""" - return getattr(dist, 'parameters', {}).get('name', dist.name) diff --git a/tensorflow_probability/python/experimental/vi/cascading_flows_test.py b/tensorflow_probability/python/experimental/vi/cascading_flows_test.py deleted file mode 100644 index 9c4393be24..0000000000 --- a/tensorflow_probability/python/experimental/vi/cascading_flows_test.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2021 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Tests for structured surrogate posteriors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -import tensorflow.compat.v1 as tf1 -import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp -from tensorflow_probability.python.experimental.vi import cascading_flows -from tensorflow_probability.python.internal import prefer_static as ps -from tensorflow_probability.python.internal import test_util - - -tfb = tfp.bijectors -tfd = tfp.distributions - - -@test_util.test_all_tf_execution_regimes -class _TrainableCFSurrogate(object): - - def _expected_num_trainable_variables(self, prior_dist): - """Infers the expected number of trainable variables for a non-nested JD.""" - prior_dists = prior_dist._get_single_sample_distributions() # pylint: disable=protected-access - expected_num_trainable_variables = 0 - for original_dist in prior_dists: - try: - original_dist = original_dist.distribution - except AttributeError: - pass - dist = cascading_flows._as_substituted_distribution(original_dist) - dist_params = dist.parameters - for param, value in dist_params.items(): - if (param not in cascading_flows._NON_STATISTICAL_PARAMS - and value is not None and param not in ('low', 'high')): - # One variable each for prior_weight, mean_field_parameter. 
- expected_num_trainable_variables += 2 - return expected_num_trainable_variables - - def test_dims_and_gradients(self): - - prior_dist = self.make_prior_dist() - - surrogate_posterior = tfp.experimental.vi.build_cf_surrogate_posterior( - prior=prior_dist) - - # Test that the correct number of trainable variables are being tracked - self.assertLen(surrogate_posterior.trainable_variables, - self._expected_num_trainable_variables(prior_dist)) - - # Test that the sample shape is correct - three_posterior_samples = surrogate_posterior.sample( - 3, seed=test_util.test_seed(sampler_type='stateless')) - three_prior_samples = prior_dist.sample( - 3, seed=test_util.test_seed(sampler_type='stateless')) - self.assertAllEqualNested( - [s.shape for s in tf.nest.flatten(three_prior_samples)], - [s.shape for s in tf.nest.flatten(three_posterior_samples)]) - - # Test that gradients are available wrt the variational parameters. - posterior_sample = surrogate_posterior.sample( - seed=test_util.test_seed(sampler_type='stateless')) - with tf.GradientTape() as tape: - posterior_logprob = surrogate_posterior.log_prob(posterior_sample) - grad = tape.gradient(posterior_logprob, - surrogate_posterior.trainable_variables) - self.assertTrue(all(g is not None for g in grad)) - - def test_initialization_is_deterministic_following_seed(self): - prior_dist = self.make_prior_dist() - - surrogate_posterior = tfp.experimental.vi.build_cf_surrogate_posterior( - prior=prior_dist, - seed=test_util.test_seed(sampler_type='stateless')) - self.evaluate( - [v.initializer for v in surrogate_posterior.trainable_variables]) - posterior_sample = surrogate_posterior.sample( - seed=test_util.test_seed(sampler_type='stateless')) - - surrogate_posterior2 = tfp.experimental.vi.build_cf_surrogate_posterior( - prior=prior_dist, - seed=test_util.test_seed(sampler_type='stateless')) - self.evaluate( - [v.initializer for v in surrogate_posterior2.trainable_variables]) - posterior_sample2 = surrogate_posterior2.sample( - seed=test_util.test_seed(sampler_type='stateless')) - - self.assertAllEqualNested(posterior_sample, posterior_sample2) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestBrownianMotion(test_util.TestCase, - _TrainableCFSurrogate): - - def make_prior_dist(self): - - def _prior_model_fn(): - innovation_noise = 0.1 - prior_loc = 0. 
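-      # Five-state Brownian-motion prior: each state is a Normal step
-      # centered at its predecessor.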
- new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise) - for _ in range(4): - new = yield tfd.Normal(loc=new, scale=innovation_noise) - - return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn) - - def make_likelihood_model(self, x, observation_noise): - - def _likelihood_model(): - for i in range(5): - yield tfd.Normal(loc=x[i], scale=observation_noise) - - return tfd.JointDistributionCoroutineAutoBatched(_likelihood_model) - - def get_observations(self, prior_dist): - observation_noise = 0.15 - ground_truth = prior_dist.sample() - likelihood = self.make_likelihood_model( - x=ground_truth, observation_noise=observation_noise) - return likelihood.sample(1) - - def get_target_log_prob(self, observations, prior_dist): - - def target_log_prob(*x): - observation_noise = 0.15 - likelihood_dist = self.make_likelihood_model( - x=x, observation_noise=observation_noise) - return likelihood_dist.log_prob(observations) + prior_dist.log_prob(x) - - return target_log_prob - - def test_fitting_surrogate_posterior(self): - - prior_dist = self.make_prior_dist() - observations = self.get_observations(prior_dist) - surrogate_posterior = tfp.experimental.vi.build_cf_surrogate_posterior( - prior=prior_dist) - target_log_prob = self.get_target_log_prob(observations, prior_dist) - - # Test vi fit surrogate posterior works - losses = tfp.vi.fit_surrogate_posterior( - target_log_prob, - surrogate_posterior, - num_steps=5, # Don't optimize to completion. - optimizer=tf.optimizers.Adam(0.1), - sample_size=10) - - # Compute posterior statistics. - with tf.control_dependencies([losses]): - posterior_samples = surrogate_posterior.sample(100) - posterior_mean = tf.nest.map_structure(tf.reduce_mean, posterior_samples) - posterior_stddev = tf.nest.map_structure(tf.math.reduce_std, - posterior_samples) - - self.evaluate(tf1.global_variables_initializer()) - _ = self.evaluate(losses) - _ = self.evaluate(posterior_mean) - _ = self.evaluate(posterior_stddev) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestEightSchools(test_util.TestCase, - _TrainableCFSurrogate): - - def make_prior_dist(self): - treatment_effects = tf.constant([28, 8, -3, 7, -1, 1, 18, 12], - dtype=tf.float32) - num_schools = ps.shape(treatment_effects)[-1] - - return tfd.JointDistributionNamed({ - 'avg_effect': - tfd.Normal(loc=0., scale=10., name='avg_effect'), - 'log_stddev': - tfd.Normal(loc=5., scale=1., name='log_stddev'), - 'school_effects': - lambda log_stddev, avg_effect: ( # pylint: disable=g-long-lambda - tfd.Independent( - tfd.Normal( - loc=avg_effect[..., None] * tf.ones(num_schools), - scale=tf.exp(log_stddev[..., None]) * tf.ones( - num_schools), - name='school_effects'), - reinterpreted_batch_ndims=1)) - }) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestEightSchoolsSample(test_util.TestCase, - _TrainableCFSurrogate): - - def make_prior_dist(self): - - return tfd.JointDistributionNamed({ - 'avg_effect': - tfd.Normal(loc=0., scale=10., name='avg_effect'), - 'log_stddev': - tfd.Normal(loc=5., scale=1., name='log_stddev'), - 'school_effects': - lambda log_stddev, avg_effect: ( # pylint: disable=g-long-lambda - tfd.Sample( - tfd.Normal( - loc=avg_effect[..., None], - scale=tf.exp(log_stddev[..., None]), - name='school_effects'), - sample_shape=[8])) - }) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestHalfNormal(test_util.TestCase, - _TrainableCFSurrogate): - - def make_prior_dist(self): - - def _prior_model_fn(): - innovation_noise = 1. 
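-      # A single HalfNormal latent; under the default CF substitution rules
-      # it is surrogated via a TruncatedNormal.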
- yield tfd.HalfNormal( - scale=innovation_noise, validate_args=True, allow_nan_stats=False) - - return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestDiscreteLatent( - test_util.TestCase, _TrainableCFSurrogate): - - def make_prior_dist(self): - - def _prior_model_fn(): - a = yield tfd.Bernoulli(logits=0.5, name='a') - yield tfd.Normal(loc=2. * tf.cast(a, tf.float32) - 1., - scale=1., name='b') - - return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn) - - -@test_util.test_all_tf_execution_regimes -class CFSurrogatePosteriorTestNesting(test_util.TestCase, - _TrainableCFSurrogate): - - def _expected_num_trainable_variables(self, _): - # Nested distributions have total of 10 params after Exponential->Gamma - # substitution, multiplied by 2 variables per param. - return 20 - - def make_prior_dist(self): - - def nested_model(): - a = yield tfd.Sample( - tfd.Sample( - tfd.Normal(0., 1.), - sample_shape=4), - sample_shape=[2], - name='a') - b = yield tfb.Sigmoid()( - tfb.Square()( - tfd.Exponential(rate=tf.exp(a))), - name='b') - # pylint: disable=g-long-lambda - yield tfd.JointDistributionSequential( - [tfd.Laplace(loc=a, scale=b), - lambda c1: tfd.Independent( - tfd.Beta(concentration1=1., - concentration0=tf.nn.softplus(c1)), - reinterpreted_batch_ndims=1), - lambda c1, c2: tfd.JointDistributionNamed({ - 'x': tfd.Gamma(concentration=tf.nn.softplus(c1), rate=c2)}) - ], name='c') - # pylint: enable=g-long-lambda - - return tfd.JointDistributionCoroutineAutoBatched(nested_model) - - -@test_util.test_all_tf_execution_regimes -class TestCFDistributionSubstitution(test_util.TestCase): - - def test_default_substitutes_trainable_families(self): - - @tfd.JointDistributionCoroutineAutoBatched - def model(): - yield tfd.Sample( - tfd.Uniform(low=-2., high=7.), - sample_shape=[2], - name='a') - yield tfd.HalfNormal(1., name='b') - yield tfd.Exponential(rate=[1., 2.], name='c') - yield tfd.Chi2(df=3., name='d') - - surrogate = tfp.experimental.vi.build_cf_surrogate_posterior( - model) - self.assertAllEqualNested(model.event_shape, surrogate.event_shape) - - surrogate_dists, _ = surrogate.sample_distributions() - self.assertIsInstance(surrogate_dists.a, tfd.Independent) - self.assertIsInstance(surrogate_dists.a.distribution, - tfd.TransformedDistribution) - self.assertIsInstance(surrogate_dists.a.distribution.distribution, - tfd.Beta) - self.assertIsInstance(surrogate_dists.b, tfd.TruncatedNormal) - self.assertIsInstance(surrogate_dists.c, tfd.Gamma) - self.assertIsInstance(surrogate_dists.d, tfd.Gamma) - - def test_can_specify_custom_substitution(self): - - @tfd.JointDistributionCoroutineAutoBatched - def centered_horseshoe(ndims=100): - global_scale = yield tfd.HalfCauchy( - loc=0., scale=1., name='global_scale') - local_scale = yield tfd.HalfCauchy( - loc=0., scale=tf.ones([ndims]), name='local_scale') - yield tfd.Normal( - loc=0., scale=tf.sqrt(global_scale * local_scale), name='weights') - - tfp.experimental.vi.register_asvi_substitution_rule( - condition=tfd.HalfCauchy, - substitution_fn=( - lambda d: tfb.Softplus(1e-6)(tfd.Normal(loc=d.loc, scale=d.scale)))) - surrogate = tfp.experimental.vi.build_cf_surrogate_posterior( - centered_horseshoe) - self.assertAllEqualNested(centered_horseshoe.event_shape, - surrogate.event_shape) - - # If the surrogate was built with names or structure differing from the - # model, so that it had to be `tfb.Restructure`'d, then this - # sample_distributions call will fail 
because the surrogate isn't an - # instance of tfd.JointDistribution. - surrogate_dists, _ = surrogate.sample_distributions() - self.assertIsInstance(surrogate_dists.global_scale.distribution, - tfd.Normal) - self.assertIsInstance(surrogate_dists.local_scale.distribution, - tfd.Normal) - self.assertIsInstance(surrogate_dists.weights, tfd.Normal) - -# TODO(kateslin): Add an ASVI surrogate posterior test for gamma distributions. -# TODO(kateslin): Add an ASVI surrogate posterior test with for a model with -# missing observations. - -if __name__ == '__main__': - tf.test.main()