From 8a8456924f4596c1a24e07477c1897259111a4cd Mon Sep 17 00:00:00 2001 From: James Bristow Date: Mon, 16 Sep 2024 17:57:30 +1200 Subject: [PATCH] adding pyabc --- app/flows/run_abc.py | 86 ++++++++++++++++++- app/pages/abc_root_system.py | 21 ++++- app/pages/optimisation_root_system.py | 21 ++++- app/pages/sensitivity_analysis_root_system.py | 21 ++++- poetry.lock | 75 +++++++++++++++- pyproject.toml | 2 + 6 files changed, 220 insertions(+), 6 deletions(-) diff --git a/app/flows/run_abc.py b/app/flows/run_abc.py index 48e297f..6581a95 100644 --- a/app/flows/run_abc.py +++ b/app/flows/run_abc.py @@ -9,11 +9,12 @@ import numpy as np import pandas as pd import plotly.express as px +import pymc as pm from joblib import dump as calibrator_dump from matplotlib import pyplot as plt from prefect import flow, task from prefect.artifacts import create_table_artifact -from prefect.task_runners import ConcurrentTaskRunner +from prefect.task_runners import SequentialTaskRunner from deeprootgen.calibration import ( SensitivityAnalysisModel, @@ -42,6 +43,11 @@ ###################################### +def distance_func(e: float, observed: np.ndarray, simulated: np.ndarray) -> float: + print(simulated) + return simulated.item() + + @task def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> None: """Running Approximate Bayesian Computation. @@ -55,6 +61,9 @@ def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> Non begin_experiment(TASK, simulation_uuid, input_parameters.simulation_tag) log_experiment_details(simulation_uuid) + # distance, statistics_list = get_calibration_summary_stats(input_parameters) + # [statistic.statistic_value for statistic in statistics_list] + config = input_parameters.dict() log_config(config, TASK) mlflow.end_run() @@ -63,7 +72,7 @@ def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> Non @flow( name="abc", description="Perform Bayesian parameter estimation for the root model using Approximate Bayesian Computation.", - task_runner=ConcurrentTaskRunner(), + task_runner=SequentialTaskRunner(), ) def run_abc_flow(input_parameters: RootCalibrationModel, simulation_uuid: str) -> None: """Flow for running Approximate Bayesian Computation. @@ -75,3 +84,76 @@ def run_abc_flow(input_parameters: RootCalibrationModel, simulation_uuid: str) - The simulation uuid. """ run_abc.submit(input_parameters, simulation_uuid) + + # names = [] + # priors = [] + # with pm.Model() as model: + # parameter_intervals = input_parameters.parameter_intervals.dict() + # for name, v in parameter_intervals.items(): + # names.append(name) + + # lower_bound = v["lower_bound"] + # upper_bound = v["upper_bound"] + # data_type = v["data_type"] + + # if data_type == "discrete": + # prior = pm.DiscreteUniform(name, lower_bound, upper_bound) + # else: + # prior = pm.Uniform(name, lower_bound, upper_bound) + # priors.append(prior) + + # params = tuple(priors) + + # def simulator_func( + # _, *parameters + # ): + # parameters = parameters[:-1] + + # parameter_specs = {} + # for i, name in enumerate(names): + # parameter_specs[name] = parameters[i].item() + + # discrepancy = calculate_summary_statistic_discrepancy( + # parameter_specs, input_parameters, statistics_list, distance + # ) + + # print(parameter_specs) + # print(discrepancy) + # return np.array([discrepancy]) + + # calibration_parameters = input_parameters.calibration_parameters + # pm.Simulator( + # "root_simulator", + # simulator_func, + # params = params, + # distance = distance_func, + # sum_stat = "identity", + # epsilon = calibration_parameters["epsilon"], + # observed = observed_values, + # ) + + # time_now = get_datetime_now() + # outdir = get_outdir() + # pgm = pm.model_to_graphviz(model = model) + # outfile = f"{time_now}-{TASK}_model_graph" + # pgm.render(format = "png", directory = outdir, filename = outfile) + # outfile = osp.join(outdir, f"{outfile}.png") + # mlflow.log_artifact(outfile) + + # trace = pm.sample_smc( + # # draws = calibration_parameters["draws"], + # model = model, + # draws = 3, + # chains = 1, + # cores = 1, + # # chains = calibration_parameters["chains"], + # # cores = calibration_parameters["cores"], + # # compute_convergence_checks = False, + # # return_inferencedata = False, + # # random_seed = input_parameters.random_seed, + # # progressbar = False + # ) + + # del trace + # del model + # print(trace) diff --git a/app/pages/abc_root_system.py b/app/pages/abc_root_system.py index bc5fec9..228d0fc 100644 --- a/app/pages/abc_root_system.py +++ b/app/pages/abc_root_system.py @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool: Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"), Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"), Input("store-summary-data", "data"), + Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"), + Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"), ) -def update_summary_data_state(summary_data: dict | None) -> tuple: +def update_summary_data_state( + summary_data: dict | None, summary_stats: list, distances: list +) -> tuple: """Update the state of the summary data. Args: summary_data (dict | None): The summary data. + summary_stats (list): + The list of summary statistics. + distances (list): + The list of distance metrics. Returns: tuple: @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple: if summary_label is None: return button_contents, [True], [True] + if summary_stats is None or distances is None: + return [summary_label], [True], [True] + + summary_stats_list = summary_stats[0] + distance_list = distances[0] + if len(summary_stats_list) == 0 or distance_list is None: + return [summary_label], [True], [True] + + if summary_stats_list[0] is None or distance_list == "": + return [summary_label], [True], [True] + return [summary_label], [False], [False] diff --git a/app/pages/optimisation_root_system.py b/app/pages/optimisation_root_system.py index dc9596a..f4793ad 100644 --- a/app/pages/optimisation_root_system.py +++ b/app/pages/optimisation_root_system.py @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool: Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"), Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"), Input("store-summary-data", "data"), + Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"), + Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"), ) -def update_summary_data_state(summary_data: dict | None) -> tuple: +def update_summary_data_state( + summary_data: dict | None, summary_stats: list, distances: list +) -> tuple: """Update the state of the summary data. Args: summary_data (dict | None): The summary data. + summary_stats (list): + The list of summary statistics. + distances (list): + The list of distance metrics. Returns: tuple: @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple: if summary_label is None: return button_contents, [True], [True] + if summary_stats is None or distances is None: + return [summary_label], [True], [True] + + summary_stats_list = summary_stats[0] + distance_list = distances[0] + if len(summary_stats_list) == 0 or distance_list is None: + return [summary_label], [True], [True] + + if summary_stats_list[0] is None or distance_list == "": + return [summary_label], [True], [True] + return [summary_label], [False], [False] diff --git a/app/pages/sensitivity_analysis_root_system.py b/app/pages/sensitivity_analysis_root_system.py index 94e9ee6..f954a3a 100644 --- a/app/pages/sensitivity_analysis_root_system.py +++ b/app/pages/sensitivity_analysis_root_system.py @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool: Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"), Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"), Input("store-summary-data", "data"), + Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"), + Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"), ) -def update_summary_data_state(summary_data: dict | None) -> tuple: +def update_summary_data_state( + summary_data: dict | None, summary_stats: list, distances: list +) -> tuple: """Update the state of the summary data. Args: summary_data (dict | None): The summary data. + summary_stats (list): + The list of summary statistics. + distances (list): + The list of distance metrics. Returns: tuple: @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple: if summary_label is None: return button_contents, [True], [True] + if summary_stats is None or distances is None: + return [summary_label], [True], [True] + + summary_stats_list = summary_stats[0] + distance_list = distances[0] + if len(summary_stats_list) == 0 or distance_list is None: + return [summary_label], [True], [True] + + if summary_stats_list[0] is None or distance_list == "": + return [summary_label], [True], [True] + return [summary_label], [False], [False] diff --git a/poetry.lock b/poetry.lock index c673d3e..9c60388 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2802,6 +2802,17 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] +[[package]] +name = "jabbar" +version = "0.0.16" +description = "Just Another Beautiful progress BAR" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jabbar-0.0.16-py3-none-any.whl", hash = "sha256:50d4392202b32a3781a8626e6895b893acf54c5e1ac1c2340b70f3d71f707d04"}, + {file = "jabbar-0.0.16.tar.gz", hash = "sha256:522f29ca04e44a25fbc3ae0419f7d7bf96af1b9d131bbd8b58899224fc5eb0f5"}, +] + [[package]] name = "jedi" version = "0.19.1" @@ -5152,6 +5163,50 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyabc" +version = "0.12.13" +description = "Distributed, likelihood-free ABC-SMC inference" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyabc-0.12.13-py3-none-any.whl", hash = "sha256:84c76e8d03edea3d67c102198f2d43d849043a5a09610370af01d38a81e3140a"}, + {file = "pyabc-0.12.13.tar.gz", hash = "sha256:ded2fc6b08432d33eb4531fe9863d596adf53dcd036f0080becb21c4b046bea7"}, +] + +[package.dependencies] +click = ">=7.1.2" +cloudpickle = ">=1.5.0" +distributed = ">=2022.10.2" +gitpython = ">=3.1.7" +jabbar = ">=0.0.10" +matplotlib = ">=3.3.0" +numpy = ">=1.19.1" +pandas = ">=2.0.1" +redis = ">=2.10.6" +scikit-learn = ">=0.23.1" +scipy = ">=1.5.2" +sqlalchemy = ">=2.0.12" + +[package.extras] +amici = ["amici (>=0.18.0)"] +autograd = ["autograd (>=1.3)"] +copasi = ["copasi-basico (>=0.8)"] +doc = ["ipython (>=8.4.0)", "nbconvert (>=6.5.0)", "nbsphinx (>=0.8.9)", "sphinx (>=6.2.1)", "sphinx-autodoc-typehints (>=1.18.3)", "sphinx-rtd-theme (>=1.2.0)"] +examples = ["notebook (>=6.1.4)"] +julia = ["julia (>=0.5.7)", "pygments (>=2.6.1)"] +migrate = ["alembic (>=1.5.4)"] +ot = ["pot (>=0.7.0)"] +petab = ["petab (>=0.2.0)"] +plotly = ["kaleido (>=0.2.1)", "plotly (>=5.3.1)"] +pyarrow = ["pyarrow (>=6.0.0)"] +r = ["cffi (>=1.14.5)", "ipython (>=7.18.1)", "pygments (>=2.6.1)", "rpy2 (>=3.4.4)"] +test = ["pytest (>=5.4.3)", "pytest-cov (>=2.10.0)", "pytest-rerunfailures (>=9.1.1)"] +test-petab = ["petabtests (>=0.0.0a6)"] +webserver-dash = ["dash (>=2.11.1)", "dash-bootstrap-components (>=1.4.2)"] +webserver-flask = ["bokeh (>=3.0.1)", "flask (>=1.1.2)", "flask-bootstrap (>=3.3.7.1)"] +yaml2sbml = ["yaml2sbml (>=0.2.1)"] + [[package]] name = "pyarrow" version = "15.0.2" @@ -6059,6 +6114,24 @@ files = [ {file = "readchar-4.2.0.tar.gz", hash = "sha256:44807cbbe377b72079fea6cba8aa91c809982d7d727b2f0dbb2d1a8084914faa"}, ] +[[package]] +name = "redis" +version = "5.0.8" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, + {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -8380,4 +8453,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "95d6a8c75238eda2219b60c8cb7401fc74e7cc121bf5a7a3eaaa2490152552f9" +content-hash = "6bcfc75618c28b6af9dd2994cfc656b418c44ad303645f8fcab8634415b3afbc" diff --git a/pyproject.toml b/pyproject.toml index 8abf3b1..f179bd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ dash-daq = "^0.5.0" ydata-profiling = "^4.10.0" adbnx-adapter = "^5.0.3" python-arango = "<8.0" +pyabc = "^0.12.13" +redis = "^5.0.8" [tool.poetry.group.torch.dependencies]