diff --git a/nmma/em/analysis.py b/nmma/em/analysis.py index 972871e1..2fde1856 100644 --- a/nmma/em/analysis.py +++ b/nmma/em/analysis.py @@ -24,26 +24,6 @@ from .utils import getFilteredMag, dataProcess from .io import loadEvent -# import functions -from ..mlmodel.dataprocessing import gen_prepend_filler, gen_append_filler, pad_the_data -from ..mlmodel.resnet import ResNet -from ..mlmodel.embedding import SimilarityEmbedding -from ..mlmodel.normalizingflows import normflow_params -from ..mlmodel.inference import cast_as_bilby_result - -# need to add these packages: -import torch -import torch.nn as nn -from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split -import torch.nn.functional as F -from nflows.nn.nets.resnet import ResidualNet -from nflows import transforms, distributions, flows -from nflows.distributions import StandardNormal -from nflows.flows import Flow -from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform -from nflows.transforms import CompositeTransform, RandomPermutation -import nflows.utils as torchutils - matplotlib.use("agg") @@ -1175,6 +1155,26 @@ def analysis(args): def nnanalysis(args): + # import functions + from ..mlmodel.dataprocessing import gen_prepend_filler, gen_append_filler, pad_the_data + from ..mlmodel.resnet import ResNet + from ..mlmodel.embedding import SimilarityEmbedding + from ..mlmodel.normalizingflows import normflow_params + from ..mlmodel.inference import cast_as_bilby_result + + # need to add these packages: + import torch + import torch.nn as nn + from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split + import torch.nn.functional as F + from nflows.nn.nets.resnet import ResidualNet + from nflows import transforms, distributions, flows + from nflows.distributions import StandardNormal + from nflows.flows import Flow + from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform + from nflows.transforms import CompositeTransform, RandomPermutation + import nflows.utils as torchutils + # only continue if the Kasen model is selected if args.model != "Ka2017": print( diff --git a/requirements.txt b/requirements.txt index 9c0bee27..ceefc895 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ pymultinest sncosmo dust_extinction arviz -p_tqdm +p_tqdm<1.4.1 tornado notebook ligo.skymap