Skip to content
Merged
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ See below for the available methods of inference, `SNPE`, `SNRE` and `SNLE`.

## Installation

`sbi` requires Python 3.7 or higher. It can be installed using `pip`:
`sbi` requires Python 3.6 or higher. It can be installed using `pip`:
```commandline
$ pip install sbi
```
Expand All @@ -35,8 +35,8 @@ We recommend to use a [`conda`](https://docs.conda.io/en/latest/miniconda.html)
environment ([Miniconda installation instructions](https://docs.conda.io/en/latest/miniconda.html])). If `conda` is installed on the system, an environment for
installing `sbi` can be created as follows:
```commandline
# Create an environment for sbi (indicate Python 3.7 or higher); activate it
$ conda create -n sbi_env python=3.7 && conda activate sbi_env
# Create an environment for sbi (indicate Python 3.6 or higher); activate it
$ conda create -n sbi_env python=3.6 && conda activate sbi_env
```

To test the installation, drop into a python prompt and run
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Type # noqa: F401
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from sbi.inference.base import NeuralInference, infer # noqa: F401
from sbi.user_input.user_input_checks import prepare_for_sbi

Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/abc/abc_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from abc import ABC
from typing import Callable, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/abc/mcabc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from numpy import ndarray
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/abc/smcabc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from numpy import ndarray
Expand Down
4 changes: 1 addition & 3 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from abc import ABC
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from warnings import warn

import torch
Expand Down
5 changes: 2 additions & 3 deletions sbi/inference/posterior.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable, Optional
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from warnings import warn

import numpy as np
Expand All @@ -14,7 +13,7 @@
from torch import multiprocessing as mp
from torch import nn

import sbi.utils as utils
from sbi import utils as utils
from sbi.mcmc import Slice, SliceSampler
from sbi.types import Array, Shape
from sbi.user_input.user_input_checks import process_x
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snle/snle_a.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor
Expand Down
5 changes: 2 additions & 3 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from abc import ABC
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import numpy as np
import torch
Expand All @@ -14,7 +13,7 @@
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

import sbi.utils as utils
from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posterior import NeuralPosterior
from sbi.types import OneOrMore, ScalarFloat
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snpe/snpe_a.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import nn
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snpe/snpe_b.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor, nn
Expand Down
5 changes: 2 additions & 3 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from abc import ABC, abstractmethod
from copy import deepcopy
from sbi.user_input.user_input_checks import check_estimator_arg
from typing import Callable, Dict, Optional, Tuple, Union, cast
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from warnings import warn

import numpy as np
Expand All @@ -17,7 +16,7 @@
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

import sbi.utils as utils
from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posterior import NeuralPosterior
from sbi.types import OneOrMore, ScalarFloat
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snpe/snpe_c.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor, eye, nn, ones
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snre/snre_a.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor, nn, ones
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/snre/snre_b.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor
Expand Down
5 changes: 2 additions & 3 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import numpy as np
import torch
Expand All @@ -11,7 +10,7 @@
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

import sbi.utils as utils
from sbi import utils as utils
from sbi.inference.base import NeuralInference
from sbi.inference.posterior import NeuralPosterior
from sbi.types import OneOrMore, ScalarFloat
Expand Down
1 change: 0 additions & 1 deletion sbi/neural_nets/classifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import torch
from pyknos.nflows.nn import nets
Expand Down
1 change: 0 additions & 1 deletion sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from warnings import warn

Expand Down
1 change: 0 additions & 1 deletion sbi/neural_nets/mdn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from pyknos.mdn.mdn import MultivariateGaussianMDN
from pyknos.nflows import flows, transforms
Expand Down
3 changes: 1 addition & 2 deletions sbi/simulators/linear_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Union, Tuple
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor
Expand Down
3 changes: 1 addition & 2 deletions sbi/simulators/simutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Callable
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from torch import Tensor
Expand Down
3 changes: 1 addition & 2 deletions sbi/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Sequence, Union, Tuple, TypeVar
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
import numpy as np
import torch

Expand Down
3 changes: 1 addition & 2 deletions sbi/user_input/user_input_checks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import warnings
from typing import Callable, Optional, Sequence, Tuple, Union, cast
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from numpy import ndarray
Expand Down
5 changes: 2 additions & 3 deletions sbi/user_input/user_input_checks_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import warnings
from typing import Optional, Sequence, Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
from scipy.stats._distn_infrastructure import rv_frozen
Expand Down Expand Up @@ -171,7 +170,7 @@ class MultipleIndependent(Distribution):

def __init__(
self, dists: Sequence[Distribution], validate_args=None,
) -> MultipleIndependent:
):
self._check_distributions(dists)

self.dists = dists
Expand Down
3 changes: 1 addition & 2 deletions sbi/utils/get_nn_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Any, Callable
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from sbi.neural_nets.flow import build_made, build_maf, build_nsf
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.classifier import (
Expand Down
2 changes: 1 addition & 1 deletion sbi/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch

from typing import Optional
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
from torch import Tensor

from sklearn.model_selection import KFold, cross_val_score
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import collections
import inspect
from typing import Optional, Tuple, Union, List
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import numpy as np
import six
from scipy.stats import gaussian_kde
Expand Down
6 changes: 3 additions & 3 deletions sbi/utils/pyroutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Any
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import pyro.distributions as dist
import pyro.poutine as poutine
from pyro import distributions as dist
from pyro import poutine as poutine
from torch.distributions import biject_to


Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import logging
from typing import Any, Dict, List, Sequence, Tuple
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar

import torch
import torch.nn as nn
from torch import nn as nn
from pyknos.nflows import transforms
from torch import Tensor, as_tensor, ones, zeros
from tqdm.auto import tqdm
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/torchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
from torch import Tensor, float32, device
from torch.distributions import Independent, Uniform
from typing import Union
from typing import Callable, Optional, Union, Dict, Any, Tuple, Union, cast, List, Sequence, TypeVar
import warnings

import sbi.utils as utils
from sbi import utils as utils
from sbi.types import Array, OneOrMore, ScalarFloat


Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
URL = "https://github.com/mackelab/sbi"
EMAIL = "sbi@mackelab.org"
AUTHOR = "Álvaro Tejero-Cantero, Jakob H. Macke, Jan-Matthis Lückmann, Conor M. Durkan, Michael Deistler, Jan Bölts"
REQUIRES_PYTHON = ">=3.7.0"
REQUIRES_PYTHON = ">=3.6.0"

REQUIRED = [
"joblib",
Expand All @@ -29,6 +29,7 @@
"pillow",
"pyro-ppl",
"pyknos==0.11",
"nflows==0.12", #Remove once pyknos is updated to 0.12
"scipy",
"tensorboard",
"torch>=1.5.1",
Expand Down
2 changes: 1 addition & 1 deletion tests/inference_with_NaN_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sbi.inference import SNPE_C, SRE, SNL, prepare_for_sbi
from torch import zeros, ones, eye

import sbi.utils as utils
from sbi import utils as utils

from sbi.simulators.linear_gaussian import (
samples_true_posterior_linear_gaussian_uniform_prior,
Expand Down
Loading