@@ -1557,7 +1557,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15571557 figsize = plot_args .get ('figsize' , None )
15581558 title = plot_args .get ('title' , results [0 ].name )
15591559 title_fontsize = plot_args .get ('title_fontsize' , None )
1560- xlabel = plot_args .get ('xlabel' , 'X ' )
1560+ xlabel = plot_args .get ('xlabel' , '' )
15611561 xlabel_fontsize = plot_args .get ('xlabel_fontsize' , None )
15621562 xticks_fontsize = plot_args .get ('xticks_fontsize' , None )
15631563 ylabel_fontsize = plot_args .get ('ylabel_fontsize' , None )
@@ -1567,6 +1567,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15671567 hbars = plot_args .get ('hbars' , True )
15681568 tight_layout = plot_args .get ('tight_layout' , True )
15691569 percentile = plot_args .get ('percentile' , 95 )
1570+ plot_mean = plot_args .get ('mean' , False )
15701571
15711572 if axes is None :
15721573 fig , ax = pyplot .subplots (figsize = figsize )
@@ -1580,6 +1581,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15801581 if res .test_distribution [0 ] == 'poisson' :
15811582 plow = scipy .stats .poisson .ppf ((1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
15821583 phigh = scipy .stats .poisson .ppf (1 - (1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
1584+ mean = res .test_distribution [1 ]
15831585 observed_statistic = res .observed_statistic
15841586 # empirical distributions
15851587 else :
@@ -1596,12 +1598,14 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15961598 else :
15971599 plow = numpy .percentile (test_distribution , (100 - percentile )/ 2. )
15981600 phigh = numpy .percentile (test_distribution , 100 - (100 - percentile )/ 2. )
1601+ mean = numpy .mean (res .test_distribution )
15991602
16001603 if not numpy .isinf (observed_statistic ): # Check if test result does not diverges
1601- low = observed_statistic - plow
1602- high = phigh - observed_statistic
1603- ax .errorbar (observed_statistic , index , xerr = numpy .array ([[low , high ]]).T ,
1604- fmt = _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ),
1604+ percentile_lims = numpy .array ([[mean - plow , phigh - mean ]]).T
1605+ ax .plot (observed_statistic , index ,
1606+ _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ))
1607+ ax .errorbar (mean , index , xerr = percentile_lims ,
1608+ fmt = 'ko' * plot_mean ,
16051609 capsize = capsize , linewidth = linewidth , ecolor = color )
16061610 # determine the limits to use
16071611 xlims .append ((plow , phigh , observed_statistic ))
@@ -1887,7 +1891,7 @@ def add_labels_for_publication(figure, style='bssa', labelsize=16):
18871891 return
18881892
18891893
1890- def plot_consistency_test (eval_results , normalize = False , one_sided_lower = True , plot_args = None , variance = None ):
1894+ def plot_consistency_test (eval_results , normalize = False , axes = None , one_sided_lower = False , variance = None , plot_args = None , show = False ):
18911895 """ Plots results from CSEP1 tests following the CSEP1 convention.
18921896
18931897 Note: All of the evaluations should be from the same type of evaluation, otherwise the results will not be
@@ -1927,8 +1931,10 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19271931 # Parse plot arguments. More can be added here
19281932 if plot_args is None :
19291933 plot_args = {}
1930- figsize = plot_args .get ('figsize' , (7 ,8 ))
1931- xlabel = plot_args .get ('xlabel' , 'X' )
1934+ figsize = plot_args .get ('figsize' , None )
1935+ title = plot_args .get ('title' , results [0 ].name )
1936+ title_fontsize = plot_args .get ('title_fontsize' , None )
1937+ xlabel = plot_args .get ('xlabel' , '' )
19321938 xlabel_fontsize = plot_args .get ('xlabel_fontsize' , None )
19331939 xticks_fontsize = plot_args .get ('xticks_fontsize' , None )
19341940 ylabel_fontsize = plot_args .get ('ylabel_fontsize' , None )
@@ -1938,15 +1944,22 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19381944 hbars = plot_args .get ('hbars' , True )
19391945 tight_layout = plot_args .get ('tight_layout' , True )
19401946 percentile = plot_args .get ('percentile' , 95 )
1947+ plot_mean = plot_args .get ('mean' , False )
1948+
1949+ if axes is None :
1950+ fig , ax = pyplot .subplots (figsize = figsize )
1951+ else :
1952+ ax = axes
1953+ fig = ax .get_figure ()
19411954
1942- fig , ax = pyplot .subplots (figsize = figsize )
19431955 xlims = []
19441956
19451957 for index , res in enumerate (results ):
19461958 # handle analytical distributions first, they are all in the form ['name', parameters].
19471959 if res .test_distribution [0 ] == 'poisson' :
19481960 plow = scipy .stats .poisson .ppf ((1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
19491961 phigh = scipy .stats .poisson .ppf (1 - (1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
1962+ mean = res .test_distribution [1 ]
19501963 observed_statistic = res .observed_statistic
19511964
19521965 elif res .test_distribution [0 ] == 'negative_binomial' :
@@ -1973,13 +1986,15 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19731986 else :
19741987 plow = numpy .percentile (test_distribution , 2.5 )
19751988 phigh = numpy .percentile (test_distribution , 97.5 )
1989+ mean = numpy .mean (res .test_distribution )
19761990
19771991 if not numpy .isinf (observed_statistic ): # Check if test result does not diverges
1978- low = observed_statistic - plow
1979- high = phigh - observed_statistic
1980- ax .errorbar (observed_statistic , index , xerr = numpy .array ([[low , high ]]).T ,
1981- fmt = _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ),
1982- capsize = 4 , linewidth = linewidth , ecolor = color , markersize = 10 , zorder = 1 )
1992+ percentile_lims = numpy .array ([[mean - plow , phigh - mean ]]).T
1993+ ax .plot (observed_statistic , index ,
1994+ _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ))
1995+ ax .errorbar (mean , index , xerr = percentile_lims ,
1996+ fmt = 'ko' * plot_mean ,
1997+ capsize = capsize , linewidth = linewidth , ecolor = color )
19831998 # determine the limits to use
19841999 xlims .append ((plow , phigh , observed_statistic ))
19852000 # we want to only extent the distribution where it falls outside of it in the acceptable tail
@@ -2001,18 +2016,23 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
20012016 except ValueError :
20022017 raise ValueError ('All EvaluationResults have infinite observed_statistics' )
20032018 ax .set_yticks (numpy .arange (len (results )))
2004- ax .set_yticklabels ([res .sim_name for res in results ], fontsize = 14 )
2019+ ax .set_yticklabels ([res .sim_name for res in results ], fontsize = ylabel_fontsize )
20052020 ax .set_ylim ([- 0.5 , len (results )- 0.5 ])
20062021 if hbars :
20072022 yTickPos = ax .get_yticks ()
20082023 if len (yTickPos ) >= 2 :
20092024 ax .barh (yTickPos , numpy .array ([99999 ] * len (yTickPos )), left = - 10000 ,
20102025 height = (yTickPos [1 ] - yTickPos [0 ]), color = ['w' , 'gray' ], alpha = 0.2 , zorder = 0 )
2011- ax .set_xlabel (xlabel , fontsize = 14 )
2012- ax .tick_params (axis = 'x' , labelsize = 13 )
2026+ ax .set_title (title , fontsize = title_fontsize )
2027+ ax .set_xlabel (xlabel , fontsize = xlabel_fontsize )
2028+ ax .tick_params (axis = 'x' , labelsize = xticks_fontsize )
20132029 if tight_layout :
20142030 ax .figure .tight_layout ()
20152031 fig .tight_layout ()
2032+
2033+ if show :
2034+ pyplot .show ()
2035+
20162036 return ax
20172037
20182038
0 commit comments