Speed of tfp.sts.fit_with_hmc #1704
Sorry for the sluggish response. Can you try this?
There will be some compilation overhead on the first run, but it should be much faster overall.
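One way to check the claim about compilation overhead is to time the first call separately from later calls. The sketch below is framework-agnostic and assumes nothing about TFP internals: `time_calls` is a hypothetical helper, and `fn` stands for any zero-argument callable, for example a fit wrapped in `tf.function`.

```python
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], n_calls: int = 3) -> List[float]:
    """Call `fn` n_calls times and return the wall-clock seconds of each call.

    For a tf.function, the first entry includes tracing (and, with
    jit_compile=True, XLA compilation); later entries approximate the
    steady-state cost. If `fn` returns GPU tensors, have it fetch the results
    (e.g. via .numpy()) so asynchronous execution is included in the timing.
    """
    if n_calls < 1:
        raise ValueError("n_calls must be >= 1")
    durations = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


# Usage with a cheap stand-in workload (hypothetical; swap in the compiled fit).
timings = time_calls(lambda: sum(range(10_000)), n_calls=3)
print(f"first call: {timings[0]:.6f}s, later calls: {timings[1:]}")
```

Comparing `timings[0]` with the later entries separates one-time compilation from per-call cost.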
@csuter How can I run `fit_with_hmc` with `jit_compile` activated? Following your example:

```python
import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf

time_series_with_nans = [-1.0, 1.0, np.nan, 2.4, np.nan, 5]
observed_time_series = tfp.sts.MaskedTimeSeries(
    time_series=time_series_with_nans, is_missing=tf.math.is_nan(time_series_with_nans)
)
# Build model using observed time series to set heuristic priors.
linear_trend_model = tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series)
model = tfp.sts.Sum([linear_trend_model], observed_time_series=observed_time_series)
# Fit model to data inside an XLA-compiled tf.function
parameter_samples, _ = tf.function(
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots), jit_compile=True, autograph=False
)(observed_time_series)
```

This gives me the following error:
I'm using
Hi,
I've been building some structural time-series models with TensorFlow Probability over the past week.
I've begun looking into the `impute_missing_values` method to fill in the missing values in my time series, but `sts.fit_with_hmc` seems incredibly slow. Just running the example at https://www.tensorflow.org/probability/api_docs/python/tfp/sts/impute_missing_values takes about 25 seconds.
Is this expected behavior? I'm running this example on CPU and didn't see any performance improvements when I ran the same block of code on GPU.
This is using Tensorflow-Probability==0.19 and Tensorflow==2.11.0
Thanks!
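As a rough, back-of-the-envelope way to reason about the ~25 seconds: assuming `fit_with_hmc` first runs `num_variational_steps` optimizer steps to find an initial state and then `num_results + num_warmup_steps` HMC steps of `num_leapfrog_steps` leapfrog steps each (each needing one log-density gradient), the number of sequential gradient evaluations can be sketched as below. The keyword names and defaults mirror the documented `fit_with_hmc` signature; `sequential_gradient_evals` itself is a hypothetical helper that ignores batching over `variational_sample_size` and `chain_batch_shape`.

```python
def sequential_gradient_evals(
    num_results: int = 100,
    num_warmup_steps: int = 50,
    num_leapfrog_steps: int = 15,
    num_variational_steps: int = 150,
) -> int:
    """Rough count of sequential log-density gradient evaluations in an HMC fit.

    Assumes (approximately) one gradient per variational optimizer step, since
    samples within a step are batched, plus one gradient per leapfrog step of
    every HMC step (warmup and kept results alike).
    """
    variational = num_variational_steps
    hmc = (num_results + num_warmup_steps) * num_leapfrog_steps
    return variational + hmc


print(sequential_gradient_evals())  # 150 + 150 * 15 = 2400
print(sequential_gradient_evals(num_results=50, num_warmup_steps=25,
                                num_variational_steps=50))  # 50 + 75 * 15 = 1175
```

If the ~25 seconds were spread evenly over the 2400 default evaluations (ignoring tracing and the imputation step), that would be roughly 10 ms each. Each log-density evaluation for an STS model runs a Kalman filter over a six-step series, so the individual operations are tiny and per-operation overhead, rather than arithmetic, plausibly dominates; that would be consistent with the GPU giving no speedup. Lowering `num_results`, `num_warmup_steps` or `num_variational_steps` reduces the count, at the cost of fewer or less well-tuned samples.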