Skip to content

Commit 6687c59

Browse files
committed
fix: changed errorbar to be defined from the test distribution mean.
feat: added option to plot test distribution's mean. styl: homogenized plot_poisson_consistency_test and plot_consistency_test
1 parent 241f3f3 commit 6687c59

File tree

1 file changed

+39
-18
lines changed

1 file changed

+39
-18
lines changed

csep/utils/plots.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15551555
figsize= plot_args.get('figsize', None)
15561556
title = plot_args.get('title', results[0].name)
15571557
title_fontsize = plot_args.get('title_fontsize', None)
1558-
xlabel = plot_args.get('xlabel', 'X')
1558+
xlabel = plot_args.get('xlabel', '')
15591559
xlabel_fontsize = plot_args.get('xlabel_fontsize', None)
15601560
xticks_fontsize = plot_args.get('xticks_fontsize', None)
15611561
ylabel_fontsize = plot_args.get('ylabel_fontsize', None)
@@ -1565,6 +1565,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15651565
hbars = plot_args.get('hbars', True)
15661566
tight_layout = plot_args.get('tight_layout', True)
15671567
percentile = plot_args.get('percentile', 95)
1568+
plot_mean = plot_args.get('mean', False)
15681569

15691570
if axes is None:
15701571
fig, ax = pyplot.subplots(figsize=figsize)
@@ -1578,6 +1579,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15781579
if res.test_distribution[0] == 'poisson':
15791580
plow = scipy.stats.poisson.ppf((1 - percentile/100.)/2., res.test_distribution[1])
15801581
phigh = scipy.stats.poisson.ppf(1 - (1 - percentile/100.)/2., res.test_distribution[1])
1582+
mean = res.test_distribution[1]
15811583
observed_statistic = res.observed_statistic
15821584
# empirical distributions
15831585
else:
@@ -1594,12 +1596,14 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15941596
else:
15951597
plow = numpy.percentile(test_distribution, (100 - percentile)/2.)
15961598
phigh = numpy.percentile(test_distribution, 100 - (100 - percentile)/2.)
1599+
mean = numpy.mean(res.test_distribution)
15971600

15981601
if not numpy.isinf(observed_statistic): # Check if test result does not diverges
1599-
low = observed_statistic - plow
1600-
high = phigh - observed_statistic
1601-
ax.errorbar(observed_statistic, index, xerr=numpy.array([[low, high]]).T,
1602-
fmt=_get_marker_style(observed_statistic, (plow, phigh), one_sided_lower),
1602+
percentile_lims = numpy.array([[mean - plow, phigh - mean]]).T
1603+
ax.plot(observed_statistic, index,
1604+
_get_marker_style(observed_statistic, (plow, phigh), one_sided_lower))
1605+
ax.errorbar(mean, index, xerr=percentile_lims,
1606+
fmt='ko'*plot_mean,
16031607
capsize=capsize, linewidth=linewidth, ecolor=color)
16041608
# determine the limits to use
16051609
xlims.append((plow, phigh, observed_statistic))
@@ -1883,8 +1887,9 @@ def add_labels_for_publication(figure, style='bssa', labelsize=16):
18831887
ax.annotate(f'({annot})', (0.025, 1.025), xycoords='axes fraction', fontsize=labelsize)
18841888

18851889
return
1886-
1887-
def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, plot_args=None, variance=None):
1890+
1891+
1892+
def plot_consistency_test(eval_results, normalize=False, axes=None, one_sided_lower=False, variance=None, plot_args=None, show=False):
18881893
""" Plots results from CSEP1 tests following the CSEP1 convention.
18891894
18901895
Note: All of the evaluations should be from the same type of evaluation, otherwise the results will not be
@@ -1924,8 +1929,10 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19241929
# Parse plot arguments. More can be added here
19251930
if plot_args is None:
19261931
plot_args = {}
1927-
figsize= plot_args.get('figsize', (7,8))
1928-
xlabel = plot_args.get('xlabel', 'X')
1932+
figsize= plot_args.get('figsize', None)
1933+
title = plot_args.get('title', results[0].name)
1934+
title_fontsize = plot_args.get('title_fontsize', None)
1935+
xlabel = plot_args.get('xlabel', '')
19291936
xlabel_fontsize = plot_args.get('xlabel_fontsize', None)
19301937
xticks_fontsize = plot_args.get('xticks_fontsize', None)
19311938
ylabel_fontsize = plot_args.get('ylabel_fontsize', None)
@@ -1935,15 +1942,22 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19351942
hbars = plot_args.get('hbars', True)
19361943
tight_layout = plot_args.get('tight_layout', True)
19371944
percentile = plot_args.get('percentile', 95)
1945+
plot_mean = plot_args.get('mean', False)
1946+
1947+
if axes is None:
1948+
fig, ax = pyplot.subplots(figsize=figsize)
1949+
else:
1950+
ax = axes
1951+
fig = ax.get_figure()
19381952

1939-
fig, ax = pyplot.subplots(figsize=figsize)
19401953
xlims = []
19411954

19421955
for index, res in enumerate(results):
19431956
# handle analytical distributions first, they are all in the form ['name', parameters].
19441957
if res.test_distribution[0] == 'poisson':
19451958
plow = scipy.stats.poisson.ppf((1 - percentile/100.)/2., res.test_distribution[1])
19461959
phigh = scipy.stats.poisson.ppf(1 - (1 - percentile/100.)/2., res.test_distribution[1])
1960+
mean = res.test_distribution[1]
19471961
observed_statistic = res.observed_statistic
19481962

19491963
elif res.test_distribution[0] == 'negative_binomial':
@@ -1970,13 +1984,15 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19701984
else:
19711985
plow = numpy.percentile(test_distribution, 2.5)
19721986
phigh = numpy.percentile(test_distribution, 97.5)
1987+
mean = numpy.mean(res.test_distribution)
19731988

19741989
if not numpy.isinf(observed_statistic): # Check if test result does not diverges
1975-
low = observed_statistic - plow
1976-
high = phigh - observed_statistic
1977-
ax.errorbar(observed_statistic, index, xerr=numpy.array([[low, high]]).T,
1978-
fmt=_get_marker_style(observed_statistic, (plow, phigh), one_sided_lower),
1979-
capsize=4, linewidth=linewidth, ecolor=color, markersize = 10, zorder=1)
1990+
percentile_lims = numpy.array([[mean - plow, phigh - mean]]).T
1991+
ax.plot(observed_statistic, index,
1992+
_get_marker_style(observed_statistic, (plow, phigh), one_sided_lower))
1993+
ax.errorbar(mean, index, xerr=percentile_lims,
1994+
fmt='ko'*plot_mean,
1995+
capsize=capsize, linewidth=linewidth, ecolor=color)
19801996
# determine the limits to use
19811997
xlims.append((plow, phigh, observed_statistic))
19821998
# we want to only extent the distribution where it falls outside of it in the acceptable tail
@@ -1998,18 +2014,23 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19982014
except ValueError:
19992015
raise ValueError('All EvaluationResults have infinite observed_statistics')
20002016
ax.set_yticks(numpy.arange(len(results)))
2001-
ax.set_yticklabels([res.sim_name for res in results], fontsize=14)
2017+
ax.set_yticklabels([res.sim_name for res in results], fontsize=ylabel_fontsize)
20022018
ax.set_ylim([-0.5, len(results)-0.5])
20032019
if hbars:
20042020
yTickPos = ax.get_yticks()
20052021
if len(yTickPos) >= 2:
20062022
ax.barh(yTickPos, numpy.array([99999] * len(yTickPos)), left=-10000,
20072023
height=(yTickPos[1] - yTickPos[0]), color=['w', 'gray'], alpha=0.2, zorder=0)
2008-
ax.set_xlabel(xlabel, fontsize=14)
2009-
ax.tick_params(axis='x', labelsize=13)
2024+
ax.set_title(title, fontsize=title_fontsize)
2025+
ax.set_xlabel(xlabel, fontsize=xlabel_fontsize)
2026+
ax.tick_params(axis='x', labelsize=xticks_fontsize)
20102027
if tight_layout:
20112028
ax.figure.tight_layout()
20122029
fig.tight_layout()
2030+
2031+
if show:
2032+
pyplot.show()
2033+
20132034
return ax
20142035

20152036

0 commit comments

Comments
 (0)