Skip to content

Commit

Permalink
implement use_standard_error in compute_summary_statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
tbhallett committed Dec 16, 2024
1 parent cdf5b3b commit 3772b4c
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/tlo/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st
import squarify

from tlo import Date, Simulation, logging, util
Expand Down Expand Up @@ -362,6 +363,7 @@ def compute_summary_statistics(
results: pd.DataFrame,
central_measure: Literal["mean", "median"] = "median",
width_of_range: float = 0.95,
use_standard_error: bool = False,
only_central: bool = False,
collapse_columns: bool = False,
) -> pd.DataFrame:
Expand All @@ -373,6 +375,8 @@ def compute_summary_statistics(
:param results: The dataframe of results to compute summary statistics of.
:param central_measure: The name of the central measure to use - either 'mean' or 'median'.
:param width_of_range: The width of the range to compute the statistics (e.g. 0.95 for the 95% interval).
:param use_standard_error: Whether the range should represent the standard error; otherwise it is just a
description of the variation of runs
:param collapse_columns: Whether to simplify the columnar index if there is only one run (cannot be done otherwise).
:param only_central: Whether to only report the central value (dropping the range).
:return: A dataframe with computed summary statistics.
Expand All @@ -388,9 +392,18 @@ def compute_summary_statistics(
else:
raise ValueError(f"Unknown stat: {central_measure}")

lower_quantile = (1. - width_of_range) / 2.
stats["lower"] = grouped_results.quantile(lower_quantile)
stats["upper"] = grouped_results.quantile(1 - lower_quantile)
if not use_standard_error:
lower_quantile = (1. - width_of_range) / 2.
stats["lower"] = grouped_results.quantile(lower_quantile)
stats["upper"] = grouped_results.quantile(1 - lower_quantile)
else:
# Use standard error concept whereby we're using the intervals to express a 95% CI on the value of the mean.
# This will make width of uncertainty become narrower with more runs.
std_deviation = grouped_results.std()
std_error = std_deviation / np.sqrt(len(grouped_results))
z_value = st.norm.ppf(1 - (1. - width_of_range) / 2.)
stats["lower"] = stats['central'] - z_value * std_error
stats["upper"] = stats['central'] + z_value * std_error

summary = pd.concat(stats, axis=1)
summary.columns = summary.columns.swaplevel(1, 0)
Expand Down

0 comments on commit 3772b4c

Please sign in to comment.