Skip to content

Commit

Permalink
adding pyabc
Browse files Browse the repository at this point in the history
  • Loading branch information
JBris committed Sep 16, 2024
1 parent 258c5eb commit 8a84569
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 6 deletions.
86 changes: 84 additions & 2 deletions app/flows/run_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import numpy as np
import pandas as pd
import plotly.express as px
import pymc as pm
from joblib import dump as calibrator_dump
from matplotlib import pyplot as plt
from prefect import flow, task
from prefect.artifacts import create_table_artifact
from prefect.task_runners import ConcurrentTaskRunner
from prefect.task_runners import SequentialTaskRunner

from deeprootgen.calibration import (
SensitivityAnalysisModel,
Expand Down Expand Up @@ -42,6 +43,11 @@
######################################


def distance_func(e: float, observed: np.ndarray, simulated: np.ndarray) -> float:
print(simulated)
return simulated.item()


@task
def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> None:
"""Running Approximate Bayesian Computation.
Expand All @@ -55,6 +61,9 @@ def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> Non
begin_experiment(TASK, simulation_uuid, input_parameters.simulation_tag)
log_experiment_details(simulation_uuid)

# distance, statistics_list = get_calibration_summary_stats(input_parameters)
# [statistic.statistic_value for statistic in statistics_list]

config = input_parameters.dict()
log_config(config, TASK)
mlflow.end_run()
Expand All @@ -63,7 +72,7 @@ def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> Non
@flow(
name="abc",
description="Perform Bayesian parameter estimation for the root model using Approximate Bayesian Computation.",
task_runner=ConcurrentTaskRunner(),
task_runner=SequentialTaskRunner(),
)
def run_abc_flow(input_parameters: RootCalibrationModel, simulation_uuid: str) -> None:
"""Flow for running Approximate Bayesian Computation.
Expand All @@ -75,3 +84,76 @@ def run_abc_flow(input_parameters: RootCalibrationModel, simulation_uuid: str) -
The simulation uuid.
"""
run_abc.submit(input_parameters, simulation_uuid)

# names = []
# priors = []
# with pm.Model() as model:
# parameter_intervals = input_parameters.parameter_intervals.dict()
# for name, v in parameter_intervals.items():
# names.append(name)

# lower_bound = v["lower_bound"]
# upper_bound = v["upper_bound"]
# data_type = v["data_type"]

# if data_type == "discrete":
# prior = pm.DiscreteUniform(name, lower_bound, upper_bound)
# else:
# prior = pm.Uniform(name, lower_bound, upper_bound)
# priors.append(prior)

# params = tuple(priors)

# def simulator_func(
# _, *parameters
# ):
# parameters = parameters[:-1]

# parameter_specs = {}
# for i, name in enumerate(names):
# parameter_specs[name] = parameters[i].item()

# discrepancy = calculate_summary_statistic_discrepancy(
# parameter_specs, input_parameters, statistics_list, distance
# )

# print(parameter_specs)
# print(discrepancy)
# return np.array([discrepancy])

# calibration_parameters = input_parameters.calibration_parameters
# pm.Simulator(
# "root_simulator",
# simulator_func,
# params = params,
# distance = distance_func,
# sum_stat = "identity",
# epsilon = calibration_parameters["epsilon"],
# observed = observed_values,
# )

# time_now = get_datetime_now()
# outdir = get_outdir()
# pgm = pm.model_to_graphviz(model = model)
# outfile = f"{time_now}-{TASK}_model_graph"
# pgm.render(format = "png", directory = outdir, filename = outfile)
# outfile = osp.join(outdir, f"{outfile}.png")
# mlflow.log_artifact(outfile)

# trace = pm.sample_smc(
# # draws = calibration_parameters["draws"],
# model = model,
# draws = 3,
# chains = 1,
# cores = 1,
# # chains = calibration_parameters["chains"],
# # cores = calibration_parameters["cores"],
# # compute_convergence_checks = False,
# # return_inferencedata = False,
# # random_seed = input_parameters.random_seed,
# # progressbar = False
# )

# del trace
# del model
# print(trace)
21 changes: 20 additions & 1 deletion app/pages/abc_root_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool:
Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"),
Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"),
Input("store-summary-data", "data"),
Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"),
Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"),
)
def update_summary_data_state(summary_data: dict | None) -> tuple:
def update_summary_data_state(
summary_data: dict | None, summary_stats: list, distances: list
) -> tuple:
"""Update the state of the summary data.
Args:
summary_data (dict | None):
The summary data.
summary_stats (list):
The list of summary statistics.
distances (list):
The list of distance metrics.
Returns:
tuple:
Expand All @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple:
if summary_label is None:
return button_contents, [True], [True]

if summary_stats is None or distances is None:
return [summary_label], [True], [True]

summary_stats_list = summary_stats[0]
distance_list = distances[0]
if len(summary_stats_list) == 0 or distance_list is None:
return [summary_label], [True], [True]

if summary_stats_list[0] is None or distance_list == "":
return [summary_label], [True], [True]

return [summary_label], [False], [False]


Expand Down
21 changes: 20 additions & 1 deletion app/pages/optimisation_root_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool:
Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"),
Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"),
Input("store-summary-data", "data"),
Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"),
Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"),
)
def update_summary_data_state(summary_data: dict | None) -> tuple:
def update_summary_data_state(
summary_data: dict | None, summary_stats: list, distances: list
) -> tuple:
"""Update the state of the summary data.
Args:
summary_data (dict | None):
The summary data.
summary_stats (list):
The list of summary statistics.
distances (list):
The list of distance metrics.
Returns:
tuple:
Expand All @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple:
if summary_label is None:
return button_contents, [True], [True]

if summary_stats is None or distances is None:
return [summary_label], [True], [True]

summary_stats_list = summary_stats[0]
distance_list = distances[0]
if len(summary_stats_list) == 0 or distance_list is None:
return [summary_label], [True], [True]

if summary_stats_list[0] is None or distance_list == "":
return [summary_label], [True], [True]

return [summary_label], [False], [False]


Expand Down
21 changes: 20 additions & 1 deletion app/pages/sensitivity_analysis_root_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,21 @@ def toggle_calibration_parameters_collapse(n: int, is_open: bool) -> bool:
Output({"index": f"{PAGE_ID}-run-sim-button", "type": ALL}, "disabled"),
Output({"index": f"{PAGE_ID}-clear-obs-data-file-button", "type": ALL}, "disabled"),
Input("store-summary-data", "data"),
Input({"index": f"{PAGE_ID}-select-summary-stats-dropdown", "type": ALL}, "value"),
Input({"index": f"{PAGE_ID}-distance-dropdown", "type": ALL}, "value"),
)
def update_summary_data_state(summary_data: dict | None) -> tuple:
def update_summary_data_state(
summary_data: dict | None, summary_stats: list, distances: list
) -> tuple:
"""Update the state of the summary data.
Args:
summary_data (dict | None):
The summary data.
summary_stats (list):
The list of summary statistics.
distances (list):
The list of distance metrics.
Returns:
tuple:
Expand All @@ -379,6 +387,17 @@ def update_summary_data_state(summary_data: dict | None) -> tuple:
if summary_label is None:
return button_contents, [True], [True]

if summary_stats is None or distances is None:
return [summary_label], [True], [True]

summary_stats_list = summary_stats[0]
distance_list = distances[0]
if len(summary_stats_list) == 0 or distance_list is None:
return [summary_label], [True], [True]

if summary_stats_list[0] is None or distance_list == "":
return [summary_label], [True], [True]

return [summary_label], [False], [False]


Expand Down
75 changes: 74 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ dash-daq = "^0.5.0"
ydata-profiling = "^4.10.0"
adbnx-adapter = "^5.0.3"
python-arango = "<8.0"
pyabc = "^0.12.13"
redis = "^5.0.8"


[tool.poetry.group.torch.dependencies]
Expand Down

0 comments on commit 8a84569

Please sign in to comment.