Skip to content

Commit

Permalink
Add summary statistics list to SNPE deployment
Browse files Browse the repository at this point in the history
  • Loading branch information
JBris committed Sep 18, 2024
1 parent 3ab80c9 commit 91694e0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
6 changes: 6 additions & 0 deletions app/flows/run_snpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def log_task(
posterior_samples: torch.Tensor,
input_parameters: RootCalibrationModel,
observed_values: list,
statistics_list: list[SummaryStatisticsModel],
names: list[str],
limits: list[tuple],
simulation_uuid: str,
Expand All @@ -220,6 +221,8 @@ def log_task(
The root calibration data model.
observed_values (list):
The list of observed_values.
statistics_list (list[SummaryStatisticsModel]):
The list of summary statistics.
names (list[str]):
The parameter names.
limits (list[tuple]):
Expand Down Expand Up @@ -327,12 +330,14 @@ def log_task(
parameter_specs, input_parameters
)

statistics_list = [statistic.dict() for statistic in statistics_list]
parameter_intervals["inference_type"] = "summary_statistics"
artifacts = {}
for obj, name in [
(inference, "inference"),
(posterior, "posterior"),
(parameter_intervals, "parameter_intervals"),
(statistics_list, "statistics_list"),
]:
outfile = osp.join(outdir, f"{time_now}-{TASK}_{name}.pkl")
artifacts[name] = outfile
Expand Down Expand Up @@ -380,6 +385,7 @@ def run_snpe(input_parameters: RootCalibrationModel, simulation_uuid: str) -> No
posterior_samples,
input_parameters,
observed_values,
statistics_list,
names,
limits,
simulation_uuid,
Expand Down
30 changes: 17 additions & 13 deletions deeprootgen/calibration/model_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
This module defines MLflow compatible models for versioning and deployment as microservices.
"""

from typing import Any

import bentoml
import mlflow
import numpy as np
Expand Down Expand Up @@ -253,14 +255,14 @@ def load_context(self, context: Context) -> None:
"""
import joblib

loaded_data = context.artifacts["inference"]
self.inference = joblib.load(loaded_data)

loaded_data = context.artifacts["posterior"]
self.posterior = joblib.load(loaded_data)
def load_data(k: str) -> Any:
artifact = context.artifacts[k]
return joblib.load(artifact)

loaded_data = context.artifacts["parameter_intervals"]
self.parameter_intervals = joblib.load(loaded_data)
self.inference = load_data("inference")
self.posterior = load_data("posterior")
self.parameter_intervals = load_data("parameter_intervals")
self.statistics_list = load_data("statistics_list")

def predict(
self, context: Context, model_input: pd.DataFrame, params: dict | None = None
Expand All @@ -283,12 +285,14 @@ def predict(
pd.DataFrame:
The model prediction.
"""
if (
self.inference is None
or self.posterior is None
or self.parameter_intervals is None
):
raise ValueError(f"The {self.task} calibrator has not been loaded.")
for prop in [
self.inference,
self.posterior,
self.parameter_intervals,
self.statistics_list,
]:
if prop is None:
raise ValueError(f"The {self.task} calibrator has not been loaded.")

observed_values = model_input["statistic_value"].values
posterior_samples = self.posterior.sample((50,), x=observed_values)
Expand Down

0 comments on commit 91694e0

Please sign in to comment.