Describe the bug
Neither the restriction estimator nor the restricted prior is picklable. The issue has come up before in #790, and I opened a discussion in #973.
Trying to pickle the restriction estimator raises AttributeError: Can't pickle local object 'build_classifier.<locals>.build_nn'
Trying to pickle the restricted prior raises AttributeError: Can't pickle local object 'get_classifier_thresholder.<locals>.classifier_thresholder'
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils import RestrictionEstimator, BoxUniform
from sbi.analysis import pairplot
import torch
import os
import pickle
_ = torch.manual_seed(2)
def simulator(theta):
    perturbed_theta = theta + 0.5 * torch.randn(2)
    perturbed_theta[theta[:, 0] < 0.0] = torch.as_tensor([float("nan"), float("nan")])
    return perturbed_theta
prior = BoxUniform(-2 * torch.ones(2), 2 * torch.ones(2))
theta, x = simulate_for_sbi(simulator, prior, 1000)
print("Simulation outputs: ", x)
restriction_estimator = RestrictionEstimator(prior=prior)
restriction_estimator.append_simulations(theta, x)
classifier = restriction_estimator.train()
restricted_prior = restriction_estimator.restrict_prior()
samples = restricted_prior.sample((10_000,))
_ = pairplot(samples, limits=[[-2, 2], [-2, 2]], fig_size=(4, 4))
new_theta, new_x = simulate_for_sbi(simulator, restricted_prior, 1000)
print("Simulation outputs: ", new_x)
restriction_estimator.append_simulations(
    new_theta, new_x
)  # Gather the new simulations in the `restriction_estimator`.
(
    all_theta,
    all_x,
    _,
) = restriction_estimator.get_simulations()  # Get all simulations run so far.
inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(all_theta, all_x).train()
posterior = inference.build_posterior()
posterior_samples = posterior.sample((10_000,), x=torch.ones(2))
_ = pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], fig_size=(3, 3))
dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, 'restriction_estimator.pickle'), 'wb') as f:
    # Pickle the restriction estimator using the highest protocol available.
    pickle.dump(restriction_estimator, f, pickle.HIGHEST_PROTOCOL)
Error Message: AttributeError: Can't pickle local object 'build_classifier.<locals>.build_nn'
Expected behavior
Should create the pickle file without error.
Additional context
I suspect the problem is caused by nested functions, but I will look into it. I will provide a PR, following the contribution guidelines, that makes the restriction_estimator picklable while the script above otherwise produces the same inference and restriction estimator.
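The nested-function suspicion can be illustrated without sbi. The following is only a minimal sketch in plain Python. The names `make_thresholder_closure`, `make_thresholder_partial`, `_threshold`, `Thresholder`, and `is_picklable` are hypothetical and are not part of sbi. The sketch shows that a function defined inside another function cannot be pickled, while a `functools.partial` over a module-level function or an instance of a module-level callable class can. Whether this maps directly onto `build_classifier` and `get_classifier_thresholder` is an assumption that still needs to be checked against the sbi source.

```python
import pickle
from functools import partial


def make_thresholder_closure(threshold):
    """Hypothetical factory that returns a nested function (a closure)."""

    def thresholder(value):
        return value > threshold

    return thresholder


def _threshold(value, threshold):
    """Module-level function, so pickle can locate it by its qualified name."""
    return value > threshold


def make_thresholder_partial(threshold):
    """Same behavior as the closure, but built from a module-level function."""
    return partial(_threshold, threshold=threshold)


class Thresholder:
    """Module-level callable class that stores the threshold as state."""

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, value):
        return value > self.threshold


def is_picklable(obj):
    """Return True if pickle.dumps succeeds, False if pickling fails."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Local objects fail here; the exception type depends on the Python version.
        return False


if __name__ == "__main__":
    print("closure picklable:", is_picklable(make_thresholder_closure(0.5)))
    print("partial picklable:", is_picklable(make_thresholder_partial(0.5)))
    print("class instance picklable:", is_picklable(Thresholder(0.5)))
```

Both alternatives keep the call signature of the closure, so code that only calls the returned object would not need to change. The callable class is the more explicit of the two because its state is visible as attributes.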
To Reproduce
Python 3.9.18 and SBI 0.22.0
This is the example code copied from handling invalid simulations, plus pickle and os.