Skip to content

Commit

Permalink
hotfix: sobol sensitivity anlaysis hanging bug
Browse files Browse the repository at this point in the history
  • Loading branch information
JBris committed Sep 16, 2024
1 parent 8560cff commit 258c5eb
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 277 deletions.
36 changes: 30 additions & 6 deletions app/flows/run_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,38 @@
# Imports
######################################


import mlflow
from prefect import context, flow, task
import numpy as np
import pandas as pd
import plotly.express as px
from joblib import dump as calibrator_dump
from matplotlib import pyplot as plt
from prefect import flow, task
from prefect.artifacts import create_table_artifact
from prefect.task_runners import ConcurrentTaskRunner

from deeprootgen.data_model import RootCalibrationModel
from deeprootgen.io import save_graph_to_db
from deeprootgen.model import RootSystemSimulation
from deeprootgen.calibration import (
SensitivityAnalysisModel,
calculate_summary_statistic_discrepancy,
get_calibration_summary_stats,
log_model,
)
from deeprootgen.data_model import RootCalibrationModel, SummaryStatisticsModel
from deeprootgen.pipeline import (
begin_experiment,
get_datetime_now,
get_outdir,
log_config,
log_experiment_details,
log_simulation,
)
from deeprootgen.statistics import DistanceMetricBase

######################################
# Constants
######################################

TASK = "abc"

######################################
# Main
Expand All @@ -33,7 +52,12 @@ def run_abc(input_parameters: RootCalibrationModel, simulation_uuid: str) -> Non
simulation_uuid (str):
The simulation uuid.
"""
print("hello")
begin_experiment(TASK, simulation_uuid, input_parameters.simulation_tag)
log_experiment_details(simulation_uuid)

config = input_parameters.dict()
log_config(config, TASK)
mlflow.end_run()


@flow(
Expand Down
19 changes: 14 additions & 5 deletions app/flows/run_sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ def log_task(
for i, si_label in enumerate(["total_si", "first_si", "second_si"]):
outfile = osp.join(outdir, f"{time_now}-{TASK}_{si_label}.csv")
si_df = si_dfs[i]

if si_label != "second_si":
si_df["name"] = names
cols = list(si_df.columns)
cols = [cols[-1]] + cols[:-1]
si_df = si_df[cols]

si_df.to_csv(outfile, index=False)
mlflow.log_artifact(outfile)

Expand All @@ -225,11 +232,6 @@ def log_task(
mlflow.log_artifact(outfile)

total_si_df = si_dfs[0]
total_si_df["name"] = names
cols = list(total_si_df.columns)
cols = [cols[-1]] + cols[:-1]
total_si_df = total_si_df[cols]

fig = px.bar(
total_si_df,
x="name",
Expand All @@ -241,6 +243,13 @@ def log_task(
fig.write_image(outfile, width=1200, height=1200)
mlflow.log_artifact(outfile)

st_confs = total_si_df.ST_conf.values
for i, sensitivity_index in enumerate(total_si_df.ST.values):
name = names[i]
st_conf = st_confs[i]
mlflow.log_metric(name, sensitivity_index)
mlflow.log_metric(f"{name}_conf", st_conf)

create_table_artifact(
key="sensitivity-analysis-indices",
table=total_si_df.to_dict(orient="records"),
Expand Down
6 changes: 3 additions & 3 deletions app/pages/eda_root_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def toggle_external_links_collapse(n: int, is_open: bool) -> bool:


@callback(
Output(f"{PAGE_ID}-observed-data-collapse", "is_open"),
[Input(f"{PAGE_ID}-observed-data-collapse-button", "n_clicks")],
[State(f"{PAGE_ID}-observed-data-collapse", "is_open")],
Output(f"{PAGE_ID}-simulated-data-collapse", "is_open"),
[Input(f"{PAGE_ID}-simulated-data-collapse-button", "n_clicks")],
[State(f"{PAGE_ID}-simulated-data-collapse", "is_open")],
)
def toggle_data_collapse(n: int, is_open: bool) -> bool:
"""Toggle the collapsible for statistics.
Expand Down
2 changes: 2 additions & 0 deletions deeprootgen/calibration/summary_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def run_calibration_simulation(
simulation_tag=input_parameters.simulation_tag, # type: ignore
random_seed=input_parameters.random_seed, # type: ignore
)

simulation.run(simulation_parameters)
return simulation, simulation_parameters

Expand Down Expand Up @@ -113,6 +114,7 @@ def calculate_summary_statistic_discrepancy(
simulation, simulation_parameters = run_calibration_simulation(
parameter_specs, input_parameters
)

node_df, _ = simulation.G.as_df()
observed_values = []
simulated_values = []
Expand Down
9 changes: 7 additions & 2 deletions deeprootgen/form/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ def build_common_components(

if component_spec.handler == "dropdown":
options_func = locate(component_spec.options_func)
summary_statistics = options_func() # type: ignore
component_instance.options = summary_statistics
options_list = options_func() # type: ignore
component_instance.options = options_list
if len(options_list) > 0:
if component_spec.kwargs["multi"]:
component_instance.value = [options_list[0]]
else:
component_instance.value = options_list[0]

if component_spec.handler == "range_slider":
component_instance.value = [
Expand Down
13 changes: 9 additions & 4 deletions deeprootgen/model/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# mypy: ignore-errors

import math
from typing import Dict, List

import networkx as nx
Expand Down Expand Up @@ -274,12 +275,14 @@ def construct_root(
def add_child_organ(
self, floor_threshold: float = 0.4, ceiling_threshold: float = 0.9
) -> "RootOrgan":
floor = int(len(self.segments) * floor_threshold)
ceiling = int(len(self.segments) * ceiling_threshold)
if floor > ceiling:
floor = math.ceil(len(self.segments) * floor_threshold)
ceiling = math.ceil(len(self.segments) * ceiling_threshold)
if floor >= ceiling:
floor, ceiling = ceiling, floor
if floor <= 0:
floor = 1
if floor == ceiling:
ceiling += 1

indx = self.rng.integers(floor, ceiling)
parent_node = self.segments[indx]
Expand Down Expand Up @@ -891,6 +894,7 @@ def init_organs(
order_type=RootType.PRIMARY.value,
position_type=RootType.OUTER.value,
)

for _ in range(input_parameters.outer_root_num):
organ = RootOrgan(
self.G.base_node,
Expand Down Expand Up @@ -939,7 +943,8 @@ def init_organs(
n_secondary_roots = self.rng.integers(
min_sec_root_num, max_sec_root_num
)
n_secondary_roots = int(n_secondary_roots * growth_sec_root)

n_secondary_roots = math.ceil(n_secondary_roots * growth_sec_root)
for _ in range(n_secondary_roots):
child_organ = parent_organ.add_child_organ(
floor_threshold=input_parameters.floor_threshold,
Expand Down
Loading

0 comments on commit 258c5eb

Please sign in to comment.