@@ -1555,7 +1555,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15551555 figsize = plot_args .get ('figsize' , None )
15561556 title = plot_args .get ('title' , results [0 ].name )
15571557 title_fontsize = plot_args .get ('title_fontsize' , None )
1558- xlabel = plot_args .get ('xlabel' , 'X ' )
1558+ xlabel = plot_args .get ('xlabel' , '' )
15591559 xlabel_fontsize = plot_args .get ('xlabel_fontsize' , None )
15601560 xticks_fontsize = plot_args .get ('xticks_fontsize' , None )
15611561 ylabel_fontsize = plot_args .get ('ylabel_fontsize' , None )
@@ -1565,6 +1565,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15651565 hbars = plot_args .get ('hbars' , True )
15661566 tight_layout = plot_args .get ('tight_layout' , True )
15671567 percentile = plot_args .get ('percentile' , 95 )
1568+ plot_mean = plot_args .get ('mean' , False )
15681569
15691570 if axes is None :
15701571 fig , ax = pyplot .subplots (figsize = figsize )
@@ -1578,6 +1579,7 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15781579 if res .test_distribution [0 ] == 'poisson' :
15791580 plow = scipy .stats .poisson .ppf ((1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
15801581 phigh = scipy .stats .poisson .ppf (1 - (1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
1582+ mean = res .test_distribution [1 ]
15811583 observed_statistic = res .observed_statistic
15821584 # empirical distributions
15831585 else :
@@ -1594,12 +1596,14 @@ def plot_poisson_consistency_test(eval_results, normalize=False, one_sided_lower
15941596 else :
15951597 plow = numpy .percentile (test_distribution , (100 - percentile )/ 2. )
15961598 phigh = numpy .percentile (test_distribution , 100 - (100 - percentile )/ 2. )
1599+ mean = numpy .mean (res .test_distribution )
15971600
15981601 if not numpy .isinf (observed_statistic ): # Check if test result does not diverges
1599- low = observed_statistic - plow
1600- high = phigh - observed_statistic
1601- ax .errorbar (observed_statistic , index , xerr = numpy .array ([[low , high ]]).T ,
1602- fmt = _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ),
1602+ percentile_lims = numpy .array ([[mean - plow , phigh - mean ]]).T
1603+ ax .plot (observed_statistic , index ,
1604+ _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ))
1605+ ax .errorbar (mean , index , xerr = percentile_lims ,
1606+ fmt = 'ko' * plot_mean ,
16031607 capsize = capsize , linewidth = linewidth , ecolor = color )
16041608 # determine the limits to use
16051609 xlims .append ((plow , phigh , observed_statistic ))
@@ -1883,8 +1887,9 @@ def add_labels_for_publication(figure, style='bssa', labelsize=16):
18831887 ax .annotate (f'({ annot } )' , (0.025 , 1.025 ), xycoords = 'axes fraction' , fontsize = labelsize )
18841888
18851889 return
1886-
1887- def plot_consistency_test (eval_results , normalize = False , one_sided_lower = True , plot_args = None , variance = None ):
1890+
1891+
1892+ def plot_consistency_test (eval_results , normalize = False , axes = None , one_sided_lower = False , variance = None , plot_args = None , show = False ):
18881893 """ Plots results from CSEP1 tests following the CSEP1 convention.
18891894
18901895 Note: All of the evaluations should be from the same type of evaluation, otherwise the results will not be
@@ -1924,8 +1929,10 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19241929 # Parse plot arguments. More can be added here
19251930 if plot_args is None :
19261931 plot_args = {}
1927- figsize = plot_args .get ('figsize' , (7 ,8 ))
1928- xlabel = plot_args .get ('xlabel' , 'X' )
1932+ figsize = plot_args .get ('figsize' , None )
1933+ title = plot_args .get ('title' , results [0 ].name )
1934+ title_fontsize = plot_args .get ('title_fontsize' , None )
1935+ xlabel = plot_args .get ('xlabel' , '' )
19291936 xlabel_fontsize = plot_args .get ('xlabel_fontsize' , None )
19301937 xticks_fontsize = plot_args .get ('xticks_fontsize' , None )
19311938 ylabel_fontsize = plot_args .get ('ylabel_fontsize' , None )
@@ -1935,15 +1942,22 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19351942 hbars = plot_args .get ('hbars' , True )
19361943 tight_layout = plot_args .get ('tight_layout' , True )
19371944 percentile = plot_args .get ('percentile' , 95 )
1945+ plot_mean = plot_args .get ('mean' , False )
1946+
1947+ if axes is None :
1948+ fig , ax = pyplot .subplots (figsize = figsize )
1949+ else :
1950+ ax = axes
1951+ fig = ax .get_figure ()
19381952
1939- fig , ax = pyplot .subplots (figsize = figsize )
19401953 xlims = []
19411954
19421955 for index , res in enumerate (results ):
19431956 # handle analytical distributions first, they are all in the form ['name', parameters].
19441957 if res .test_distribution [0 ] == 'poisson' :
19451958 plow = scipy .stats .poisson .ppf ((1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
19461959 phigh = scipy .stats .poisson .ppf (1 - (1 - percentile / 100. )/ 2. , res .test_distribution [1 ])
1960+ mean = res .test_distribution [1 ]
19471961 observed_statistic = res .observed_statistic
19481962
19491963 elif res .test_distribution [0 ] == 'negative_binomial' :
@@ -1970,13 +1984,15 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19701984 else :
19711985 plow = numpy .percentile (test_distribution , 2.5 )
19721986 phigh = numpy .percentile (test_distribution , 97.5 )
1987+ mean = numpy .mean (res .test_distribution )
19731988
19741989 if not numpy .isinf (observed_statistic ): # Check if test result does not diverges
1975- low = observed_statistic - plow
1976- high = phigh - observed_statistic
1977- ax .errorbar (observed_statistic , index , xerr = numpy .array ([[low , high ]]).T ,
1978- fmt = _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ),
1979- capsize = 4 , linewidth = linewidth , ecolor = color , markersize = 10 , zorder = 1 )
1990+ percentile_lims = numpy .array ([[mean - plow , phigh - mean ]]).T
1991+ ax .plot (observed_statistic , index ,
1992+ _get_marker_style (observed_statistic , (plow , phigh ), one_sided_lower ))
1993+ ax .errorbar (mean , index , xerr = percentile_lims ,
1994+ fmt = 'ko' * plot_mean ,
1995+ capsize = capsize , linewidth = linewidth , ecolor = color )
19801996 # determine the limits to use
19811997 xlims .append ((plow , phigh , observed_statistic ))
19821998 # we want to only extent the distribution where it falls outside of it in the acceptable tail
@@ -1998,18 +2014,23 @@ def plot_consistency_test(eval_results, normalize=False, one_sided_lower=True, p
19982014 except ValueError :
19992015 raise ValueError ('All EvaluationResults have infinite observed_statistics' )
20002016 ax .set_yticks (numpy .arange (len (results )))
2001- ax .set_yticklabels ([res .sim_name for res in results ], fontsize = 14 )
2017+ ax .set_yticklabels ([res .sim_name for res in results ], fontsize = ylabel_fontsize )
20022018 ax .set_ylim ([- 0.5 , len (results )- 0.5 ])
20032019 if hbars :
20042020 yTickPos = ax .get_yticks ()
20052021 if len (yTickPos ) >= 2 :
20062022 ax .barh (yTickPos , numpy .array ([99999 ] * len (yTickPos )), left = - 10000 ,
20072023 height = (yTickPos [1 ] - yTickPos [0 ]), color = ['w' , 'gray' ], alpha = 0.2 , zorder = 0 )
2008- ax .set_xlabel (xlabel , fontsize = 14 )
2009- ax .tick_params (axis = 'x' , labelsize = 13 )
2024+ ax .set_title (title , fontsize = title_fontsize )
2025+ ax .set_xlabel (xlabel , fontsize = xlabel_fontsize )
2026+ ax .tick_params (axis = 'x' , labelsize = xticks_fontsize )
20102027 if tight_layout :
20112028 ax .figure .tight_layout ()
20122029 fig .tight_layout ()
2030+
2031+ if show :
2032+ pyplot .show ()
2033+
20132034 return ax
20142035
20152036
0 commit comments