Run fit_with_hmc with jit_compile activated #1831

Open
williamjamir opened this issue Aug 13, 2024 · 0 comments

williamjamir commented Aug 13, 2024

I'm trying to run the following code:

import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf

time_series_with_nans = [-1.0, 1.0, np.nan, 2.4, np.nan, 5]
observed_time_series = tfp.sts.MaskedTimeSeries(
    time_series=time_series_with_nans, is_missing=tf.math.is_nan(time_series_with_nans)
)

# Build model using observed time series to set heuristic priors.
linear_trend_model = tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series)
model = tfp.sts.Sum([linear_trend_model], observed_time_series=observed_time_series)

# Fit model to data
parameter_samples, _ = tf.function(
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots), jit_compile=True, autograph=False
)(observed_time_series)

Using JIT, as suggested in this comment on #1704 (comment), gives me the following error:

parameter_samples, _ = tf.function(
        func=lambda ots: tfp.sts.fit_with_hmc(model, ots),
        jit_compile=True,
        autograph=False)(observed_time_series)

test_jit_hmc.py:262: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.env/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153: in error_handler
    raise e.with_traceback(filtered_tb) from None
incrementality/prefect/flows/model_training.py:263: in <lambda>
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots),
.env/lib/python3.11/site-packages/tensorflow_probability/python/sts/fitting.py:466: in fit_with_hmc
    variational_posterior = build_factored_surrogate_posterior(
.env/lib/python3.11/site-packages/tensorflow_probability/python/sts/fitting.py:173: in build_factored_surrogate_posterior
    return experimental_vi.build_factored_surrogate_posterior(
.env/lib/python3.11/site-packages/tensorflow_probability/python/internal/trainable_state_util.py:337: in build_stateful_trainable
    tf.nest.map_structure(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t = <tf.Tensor 'fit_with_hmc/build_factored_surrogate_posterior/build_factored_surrogate_posterior/Normal_trainable_variables/normal/stateless_random_normal:0' shape=() dtype=float32>, n = 'loc'

>   lambda t, n=name: t if t is None else tf.Variable(t, name=n),
    value, expand_composites=True))
E   ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

.env/lib/python3.11/site-packages/tensorflow_probability/python/internal/trainable_state_util.py:338: ValueError
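For context, this ValueError comes from a general tf.function rule rather than from STS specifically: a tf.function may create tf.Variables only on its first trace, and it traces the body again to check that no new variables are created. The minimal sketch below uses plain TensorFlow (no STS; the function names are hypothetical) to reproduce the same kind of ValueError with jit_compile=True, and shows the usual pattern of creating the variable once, outside the compiled function. It does not show whether fit_with_hmc itself can be restructured this way.

```python
import tensorflow as tf


@tf.function(jit_compile=True, autograph=False)
def creates_variable_per_trace(x):
    # A new tf.Variable is built every time this Python body is traced.
    # tf.function traces a second time to check for variable creation,
    # and that second trace raises ValueError.
    v = tf.Variable(1.0)
    return v + x


# Usual workaround pattern: create the variable once, outside tf.function,
# so the compiled function only reads it.
offset = tf.Variable(1.0)


@tf.function(jit_compile=True, autograph=False)
def reads_existing_variable(x):
    return offset + x


def raises_value_error(fn, x):
    """Return True if calling fn(x) raises ValueError."""
    try:
        fn(x)
    except ValueError:
        return True
    return False


print(raises_value_error(creates_variable_per_trace, tf.constant(2.0)))
print(float(reads_existing_variable(tf.constant(2.0))))
```

In the traceback above, the variables are created inside build_factored_surrogate_posterior, which is called from fit_with_hmc, so the first pattern is the one the repro hits.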

It looks like tfp.sts.fit_with_hmc creates tf.Variables during its execution (the traceback shows them being created inside build_factored_surrogate_posterior), which raises the following questions:

  • How can I enable JIT for this method, or is this a known limitation?
  • If it is a limitation, and given that neither GPU nor JAX works for STS (see #1646 (comment)), are there any other ways to speed up fit_with_hmc?

I'm using:

Python 3.11

pip list | grep tensorflow
tensorflow                         2.17.0
tensorflow-probability             0.24.0

Using TensorFlow 2.16 also produces the same error:

 pip list | grep tensor
tensorboard                        2.16.2
tensorboard-data-server            0.7.2
tensorflow                         2.16.1
tensorflow-io-gcs-filesystem       0.36.0
tensorflow-probability             0.24.0