diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61b9bdfc..ea957468 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,11 +14,11 @@ repos: - --unsafe - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.292 + rev: v0.1.0 hooks: - id: ruff args: ["--fix"] diff --git a/src/chainconsumer/analysis.py b/src/chainconsumer/analysis.py index da67b11e..0dd0ede1 100644 --- a/src/chainconsumer/analysis.py +++ b/src/chainconsumer/analysis.py @@ -53,6 +53,7 @@ def __init__(self, parent: ChainConsumer): SummaryStatistic.MEAN: self.get_parameter_summary_mean, SummaryStatistic.CUMULATIVE: self.get_parameter_summary_cumulative, SummaryStatistic.MAX_CENTRAL: self.get_parameter_summary_max_central, + SummaryStatistic.MEDIAN: self.get_parameter_summary_median, } def get_latex_table( @@ -465,6 +466,12 @@ def get_parameter_summary_max_central(self, chain, parameter): return Bound(lower=xvals[0], center=x, upper=xvals[1]) + def get_parameter_summary_median(self, chain, parameter): + vals = 100 * np.array([0.5 - 0.5 * chain.summary_area, 0.5, 0.5 + 0.5 * chain.summary_area]) + xvals = np.percentile(chain.get_data(parameter), vals) + + return Bound(lower=xvals[0], center=xvals[1], upper=xvals[2]) + if __name__ == "__main__": from .chainconsumer import ChainConsumer diff --git a/src/chainconsumer/statistics.py b/src/chainconsumer/statistics.py index e3c83a71..b5d87f30 100644 --- a/src/chainconsumer/statistics.py +++ b/src/chainconsumer/statistics.py @@ -18,3 +18,7 @@ class SummaryStatistic(Enum): MEAN = "mean" """As per the cumulative method, except the central value is placed in the midpoint between the upper and lower boundary. Not recommended, but was requested.""" + + MEDIAN = "median" + """The central point is set to median of the pdf, and the upper and the upper + and lower bounds are determined by the percentiles of the pdf."""