
restricted prior is not picklable #975

Closed
danielmk opened this issue Mar 5, 2024 · 0 comments
Labels
bug Something isn't working


@danielmk
Contributor

danielmk commented Mar 5, 2024

Describe the bug
Neither the restriction estimator nor the restricted prior is picklable. The issue has come up here before (#790), and I opened a discussion (#973).

Trying to pickle the restriction estimator raises AttributeError: Can't pickle local object 'build_classifier.<locals>.build_nn'

Trying to pickle the restricted prior raises AttributeError: Can't pickle local object 'get_classifier_thresholder.<locals>.classifier_thresholder'
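Both messages come from Python's `pickle` module, which stores functions by reference to their importable qualified name. A function defined inside another function has `<locals>` in its qualified name and cannot be looked up that way. The standalone sketch below reproduces the same class of error without sbi; `build_classifier` and `build_nn` here are hypothetical stand-ins that only mimic the names in the traceback.

```python
import pickle


def build_classifier():
    # Hypothetical stand-in: returns a nested (local) function, mimicking
    # the pattern named in the traceback.
    def build_nn(batch_x, batch_y):
        return (batch_x, batch_y)

    return build_nn


def try_pickle(obj):
    """Return None if obj can be pickled, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as err:
        # Python 3.9 raises AttributeError for local objects; newer
        # versions may raise pickle.PicklingError instead.
        return str(err)


if __name__ == "__main__":
    print(try_pickle(build_classifier()))
```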

To Reproduce
Python 3.9.18 and SBI 0.22.0

This is the example code from the "Handling invalid simulations" tutorial, with `pickle` and `os` added.

```python
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils import RestrictionEstimator, BoxUniform
from sbi.analysis import pairplot
import torch
import os
import pickle

_ = torch.manual_seed(2)

def simulator(theta):
    perturbed_theta = theta + 0.5 * torch.randn(2)
    perturbed_theta[theta[:, 0] < 0.0] = torch.as_tensor([float("nan"), float("nan")])
    return perturbed_theta

prior = BoxUniform(-2 * torch.ones(2), 2 * torch.ones(2))

theta, x = simulate_for_sbi(simulator, prior, 1000)
print("Simulation outputs: ", x)

restriction_estimator = RestrictionEstimator(prior=prior)

restriction_estimator.append_simulations(theta, x)
classifier = restriction_estimator.train()

restricted_prior = restriction_estimator.restrict_prior()
samples = restricted_prior.sample((10_000,))
_ = pairplot(samples, limits=[[-2, 2], [-2, 2]], fig_size=(4, 4))

new_theta, new_x = simulate_for_sbi(simulator, restricted_prior, 1000)
print("Simulation outputs: ", new_x)

restriction_estimator.append_simulations(
    new_theta, new_x
)  # Gather the new simulations in the `restriction_estimator`.
(
    all_theta,
    all_x,
    _,
) = restriction_estimator.get_simulations()  # Get all simulations run so far.

inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(all_theta, all_x).train()
posterior = inference.build_posterior()

posterior_samples = posterior.sample((10_000,), x=torch.ones(2))
_ = pairplot(posterior_samples, limits=[[-2, 2], [-2, 2]], fig_size=(3, 3))

dirname = os.path.dirname(__file__)

with open(os.path.join(dirname, 'restriction_estimator.pickle'), 'wb') as f:
    # Pickle the restriction estimator using the highest protocol available.
    pickle.dump(restriction_estimator, f, pickle.HIGHEST_PROTOCOL)
```

Error Message:
AttributeError: Can't pickle local object 'build_classifier.<locals>.build_nn'
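To narrow down which part of an object breaks pickling, one can try to pickle each attribute separately. The sketch below assumes the object keeps its state in `__dict__` (true for ordinary Python classes); `find_unpicklable_attributes` and `Example` are hypothetical helpers for illustration, not part of sbi.

```python
import pickle


def find_unpicklable_attributes(obj):
    """Map attribute name -> error text for each attribute of obj that fails to pickle."""
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as err:  # pickling can fail with several error types
            failures[name] = f"{type(err).__name__}: {err}"
    return failures


class Example:
    """Toy object holding one picklable and one unpicklable attribute."""

    def __init__(self):
        self.data = [1, 2, 3]
        self.build_nn = lambda x: x  # lambdas are local objects and cannot be pickled


if __name__ == "__main__":
    for name, error in find_unpicklable_attributes(Example()).items():
        print(name, "->", error)
```

Applied to a `RestrictionEstimator` instance, a check like this should point at the attribute(s) holding the local function; the exact attribute names depend on the sbi version.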

Expected behavior
The script should create the pickle file without error.

Additional context
I suspect the problem is caused by nested (local) functions, but I will look into it. Following the contribution guidelines, I will open a PR that makes `restriction_estimator` picklable while the script above otherwise produces the same inference results and restriction estimator.
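One common way to make such objects picklable, sketched here under the assumption that the nested function only closes over a few plain values, is to move it to module level and bind those values with `functools.partial`, or to replace the closure with a small callable class. `pickle` then stores the function or class by its importable name plus the bound state. All names below are hypothetical, and this is not necessarily the change that was made in sbi.

```python
import functools
import pickle


def build_nn(batch_x, batch_y, hidden_features):
    # Module-level replacement for a nested function: whatever the closure
    # captured is now an explicit parameter.
    return (len(batch_x), len(batch_y), hidden_features)


def build_classifier(hidden_features=10):
    # Return a partial of the module-level function instead of a local
    # function; partial objects pickle if their func and args do.
    return functools.partial(build_nn, hidden_features=hidden_features)


class ClassifierThresholder:
    """Callable-class alternative to a nested 'classifier_thresholder' closure."""

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, scores):
        # True where a score exceeds the stored threshold.
        return [s > self.threshold for s in scores]


if __name__ == "__main__":
    builder = pickle.loads(pickle.dumps(build_classifier(hidden_features=20)))
    print(builder([1, 2], [3, 4, 5]))  # (2, 3, 20)
    thresholder = pickle.loads(pickle.dumps(ClassifierThresholder(0.5)))
    print(thresholder([0.1, 0.7, 0.9]))  # [False, True, True]
```

Both variants survive a `pickle.dumps`/`pickle.loads` round trip, whereas returning a nested function does not.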
