forked from steggema/cms-das-tau
-
Notifications
You must be signed in to change notification settings - Fork 0
/
roc_tools.py
140 lines (109 loc) · 4.03 KB
/
roc_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import print_function
import ROOT
from ROOT import gROOT, gStyle, TEfficiency
from cms_style import cms_style
# Hand ROOT's global style object to the project's cms_style helper.
cms_style(gStyle)

# Hex palette, the ROOT colour indices derived from it, and the marker styles
# used by the plotting helpers in this module.
hex_colours = ['#7fc97f','#beaed4','#fdc086','#386cb0','#f0027f','#bf5b17','#666666','#ffff99'] # dark
# The loop variable is deliberately not called `hex`: that would shadow the
# builtin, and under Python 2 the comprehension variable leaks into module
# scope and would rebind `hex` for the rest of the module.
colours = [ROOT.TColor.GetColor(hex_code) for hex_code in hex_colours]
markers = [20, 21, 22, 23, 24, 25, 26, 27]
def hists_to_roc(hsig, hbg, w_error=False):
    '''Produce ROC curve from 2 input histograms hsig and hbg.

    Partly adapted from Giovanni's ttH code.

    The bin contents of both histograms (under- and overflow included) are
    accumulated and converted into points with x = cumulative signal
    fraction and y = cumulative background fraction. If the signal mean is
    larger than the background mean, the bins are accumulated starting from
    the high end instead of the low end.

    Args:
        hsig: signal histogram; must provide GetNbinsX(), GetBinContent(i)
            and GetMean().
        hbg: background histogram with the same interface. It is read over
            the same bin indices as hsig.
        w_error: if True, return a ROOT.TGraphAsymmErrors whose errors are
            68.3% Clopper-Pearson intervals on both axes; otherwise return
            a plain ROOT.TGraph.

    Returns:
        The ROC graph, or None (after printing a warning) if either
        histogram sums to zero or if no cumulative point could be recorded
        because all entries sit in the first cumulative bin.
    '''
    nbins = hsig.GetNbinsX() + 2  # include under/overflow; remove events not passing selection
    si = [hsig.GetBinContent(i) for i in range(nbins)]
    bi = [hbg.GetBinContent(i) for i in range(nbins)]

    # Accumulate from the high end when signal sits above background.
    if hsig.GetMean() > hbg.GetMean():
        si.reverse()
        bi.reverse()

    sums, sumb = sum(si), sum(bi)
    if sums == 0 or sumb == 0:
        print('WARNING: Either signal or background histogram empty', sums, sumb)
        return None

    # make cumulative
    for i in range(1, nbins):
        si[i] += si[i - 1]
        bi[i] += bi[i - 1]

    fullsi, fullbi = si[:], bi[:]
    si, bi = [], []
    for i in range(1, nbins):
        # skip negative weights: a cumulative value below the last accepted
        # point would make the curve non-monotonic
        if si and (fullsi[i] < si[-1] or fullbi[i] < bi[-1]):
            continue
        # skip repetitions
        if fullsi[i] != fullsi[i - 1] or fullbi[i] != fullbi[i - 1]:
            si.append(fullsi[i])
            bi.append(fullbi[i])

    # If the cumulative sums never change after the first bin, nothing was
    # recorded above; popping below would raise IndexError, so bail out the
    # same way as for empty histograms.
    if not si:
        print('WARNING: No ROC points found, all entries are in the first cumulative bin')
        return None

    # Remove the trivial (1, 1) points
    si.pop()
    bi.pop()

    # NOTE(review): with exactly two remaining points only the first is
    # kept; the reason is not evident from this file, behaviour preserved.
    if len(si) == 2:
        si = [si[0]]
        bi = [bi[0]]

    bins = len(si)

    if not w_error:
        roc = ROOT.TGraph(bins)
        for i in range(bins):
            roc.SetPoint(i, si[i] / sums, bi[i] / sumb)
        return roc

    interval = 0.683  # confidence level passed to ClopperPearson
    roc = ROOT.TGraphAsymmErrors(bins)
    for i in range(bins):
        eff_s = si[i] / sums
        eff_b = bi[i] / sumb
        e_s_low = eff_s - TEfficiency.ClopperPearson(sums, si[i], interval, False)
        e_s_up = TEfficiency.ClopperPearson(sums, si[i], interval, True) - eff_s
        e_b_low = eff_b - TEfficiency.ClopperPearson(sumb, bi[i], interval, False)
        e_b_up = TEfficiency.ClopperPearson(sumb, bi[i], interval, True) - eff_b
        roc.SetPoint(i, eff_s, eff_b)
        roc.SetPointError(i, e_s_low, e_s_up, e_b_low, e_b_up)
    return roc
def make_legend(rocs, textSize=0.035, left=True):
    '''Create, draw and return a TLegend for the given ROC graphs.

    rocs is a sequence of (label, graph) pairs; each pair becomes one
    legend entry drawn with the 'lp' option. The legend is placed at the
    left or right of the pad depending on `left`, and its lower edge moves
    down by one textSize for every entry beyond the third.
    '''
    extra_rows = max(len(rocs) - 3, 0)
    if left:
        x_lo, x_hi = .18, .4
    else:
        x_lo, x_hi = .68, .95
    y_lo = .76 - textSize * extra_rows
    y_hi = .88

    legend = ROOT.TLegend(x_lo, y_lo, x_hi, y_hi)

    # Borderless, unfilled box.
    legend.SetFillColor(0)
    legend.SetShadowColor(0)
    legend.SetLineColor(0)
    legend.SetLineWidth(0)

    legend.SetTextFont(42)
    legend.SetTextSize(textSize)

    for label, graph in rocs:
        legend.AddEntry(graph, label, 'lp')

    legend.Draw()
    return legend
def make_roc_plot(rocs, set_name='rocs', ymin=0., ymax=1., xmin=0., xmax=1., logy=False, formats=()):
    '''Plots multiple ROC curves (TGraph derivatives)

    Graphs with more than 10 points are collected in a TMultiGraph and
    drawn as lines; graphs with 10 points or fewer are drawn on top as
    individual markers. Colours and marker styles are taken from the
    module-level `colours` and `markers` lists and are reused cyclically
    when there are more graphs than entries.

    Args:
        rocs: list of graphs. The legend label of each graph is its
            `title` attribute if present, otherwise its GetTitle().
        set_name: name of the TMultiGraph and base name of output files.
        ymin, ymax, xmin, xmax: axis ranges.
        logy: request a logarithmic y axis; only applied if ymin > 0.
        formats: file extensions (e.g. 'pdf') for which the canvas is
            printed as set_name + '.' + extension.

    Returns:
        Tuple (TMultiGraph, TCanvas). The legend is stored on the
        multigraph as its `leg` attribute.
    '''
    allrocs = ROOT.TMultiGraph(set_name, '')
    point_graphs = []
    i_marker = 0
    for i_col, graph in enumerate(rocs):
        # Cycle through the palette instead of raising IndexError when more
        # graphs than colours are passed in.
        col = colours[i_col % len(colours)]
        graph.SetLineColor(col)
        graph.SetMarkerColor(col)
        graph.SetLineWidth(3)
        graph.SetMarkerStyle(0)
        if graph.GetN() > 10:
            allrocs.Add(graph)
        else:
            # Sparse graphs are shown as markers rather than as a line.
            graph.SetMarkerStyle(markers[i_marker % len(markers)])
            i_marker += 1
            graph.SetMarkerSize(1)
            point_graphs.append(graph)

    c = ROOT.TCanvas()

    # NOTE(review): if every graph has <= 10 points the multigraph stays
    # empty; the axis calls below are not verified for that case.
    allrocs.Draw('APL')

    allrocs.GetXaxis().SetTitle('#epsilon_{s}')
    allrocs.GetYaxis().SetTitle('#epsilon_{b}')
    allrocs.GetYaxis().SetDecimals(True)

    allrocs.GetYaxis().SetRangeUser(ymin, ymax)
    allrocs.GetXaxis().SetRangeUser(xmin, xmax)

    if logy:
        if ymin > 0.:
            c.SetLogy()
        else:
            print('Cannot set logarithmic y axis if minimum y value is <= 0.')

    # Redraw so the axis settings above take effect.
    allrocs.Draw('APL')

    for graph in point_graphs:
        graph.Draw('P')

    # Prefer the ad-hoc Python `title` attribute used by callers, but fall
    # back to the ROOT title for graphs that were never given one.
    labels = []
    for graph in rocs:
        try:
            labels.append(graph.title)
        except AttributeError:
            labels.append(graph.GetTitle())

    # Stored on the multigraph so the legend stays referenced after return.
    allrocs.leg = make_legend(list(zip(labels, rocs)))

    for f in formats:
        c.Print(set_name + '.' + f)

    return allrocs, c