Skip to content

Commit ff86c30

Browse files
committed
[RF] Implement SumW2 correction in new BatchMode with RooFitDriver
This has not been implemented so far. This commit also includes a unit test for it. For easier toggling of squared weights, a new virtual function `RooAbsArg::applyWeightSquared` was introduced such that one doesn't have to manually pick up the likelihood classes from the computation graph when applying the squared-weights correction.
1 parent 257bc91 commit ff86c30

File tree

10 files changed

+143
-73
lines changed

10 files changed

+143
-73
lines changed

roofit/roofitcore/inc/RooAbsArg.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,8 @@ class RooAbsArg : public TNamed, public RooPrintable {
599599
virtual bool canComputeBatchWithCuda() const { return false; }
600600
virtual bool isReducerNode() const { return false; }
601601

602+
virtual void applyWeightSquared(bool flag);
603+
602604
operator RooBatchCompute::DataKey() const { return RooBatchCompute::DataKey::create(this->namePtr()); }
603605

604606
protected:

roofit/roofitcore/inc/RooAbsPdf.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ class RooAbsPdf : public RooAbsReal {
406406

407407
private:
408408
int calcAsymptoticCorrectedCovariance(RooMinimizer& minimizer, RooAbsData const& data);
409-
int calcSumW2CorrectedCovariance(RooMinimizer& minimizer, RooAbsReal const& nll) const;
409+
int calcSumW2CorrectedCovariance(RooMinimizer& minimizer, RooAbsReal & nll) const;
410410

411411
ClassDefOverride(RooAbsPdf,5) // Abstract PDF with normalization support
412412
};

roofit/roofitcore/inc/RooNLLVar.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class RooNLLVar : public RooAbsOptTestStatistic {
5353

5454
~RooNLLVar() override;
5555

56-
void applyWeightSquared(Bool_t flag) ;
56+
void applyWeightSquared(bool flag) override;
5757

5858
Double_t defaultErrorLevel() const override { return 0.5 ; }
5959

roofit/roofitcore/res/RooFitDriver.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ class RooFitDriver {
127127

128128
double getValV(const RooArgSet *) const override { return evaluate(); }
129129

130+
void applyWeightSquared(bool flag) override
131+
{
132+
const_cast<RooAbsReal &>(_driver->topNode()).applyWeightSquared(flag);
133+
}
134+
130135
protected:
131136
double evaluate() const override { return _driver ? _driver->getVal() : 0.0; }
132137

roofit/roofitcore/res/RooNLLVarNew.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ namespace Experimental {
2626
class RooNLLVarNew : public RooAbsReal {
2727

2828
public:
29+
// The names for the weight variables that the RooNLLVarNew expects
30+
static constexpr auto weightVarName = "_weight";
31+
static constexpr auto weightVarNameSumW2 = "_weight_sumW2";
32+
static constexpr auto weightVarNameSumW2Suffix = "_sumW2";
33+
2934
RooNLLVarNew(){};
3035
RooNLLVarNew(const char *name, const char *title, RooAbsPdf &pdf, RooArgSet const &observables, RooAbsReal *weight,
3136
bool isExtended, std::string const &rangeName);
@@ -41,13 +46,17 @@ class RooNLLVarNew : public RooAbsReal {
4146
void computeBatch(cudaStream_t *, double *output, size_t nOut, RooBatchCompute::DataMap &) const override;
4247
inline bool isReducerNode() const override { return true; }
4348

49+
RooArgSet prefixObservableAndWeightNames(std::string const &prefix);
50+
51+
void applyWeightSquared(bool flag) override;
52+
53+
protected:
4454
void setObservables(RooArgSet const &observables)
4555
{
4656
_observables.clear();
4757
_observables.add(observables);
4858
}
4959

50-
protected:
5160
RooTemplateProxy<RooAbsPdf> _pdf;
5261
RooArgSet _observables;
5362
std::unique_ptr<RooTemplateProxy<RooAbsReal>> _weight;

roofit/roofitcore/src/BatchModeHelpers.cxx

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <string>
2929

30+
using ROOT::Experimental::RooFitDriver;
31+
using ROOT::Experimental::RooNLLVarNew;
32+
3033
namespace {
3134

3235
std::unique_ptr<RooAbsArg> prepareSimultaneousModelForBatchMode(RooSimultaneous &simPdf, RooArgSet &observables,
@@ -40,8 +43,8 @@ std::unique_ptr<RooAbsArg> prepareSimultaneousModelForBatchMode(RooSimultaneous
4043
auto const &catName = catItem.first;
4144
auto *pdf = simPdf.getPdf(catName.c_str());
4245
auto nllName = std::string("nll_") + pdf->GetName();
43-
nllTerms.add(*new ROOT::Experimental::RooNLLVarNew(nllName.c_str(), nllName.c_str(), *pdf, observables, weight,
44-
isExtended, rangeName));
46+
nllTerms.add(
47+
*new RooNLLVarNew(nllName.c_str(), nllName.c_str(), *pdf, observables, weight, isExtended, rangeName));
4548
}
4649

4750
RooArgSet newObservables;
@@ -50,22 +53,8 @@ std::unique_ptr<RooAbsArg> prepareSimultaneousModelForBatchMode(RooSimultaneous
5053
std::size_t iNLL = 0;
5154
for (auto const &catItem : simPdf.indexCat()) {
5255
auto const &catName = catItem.first;
53-
auto &nll = nllTerms[iNLL];
54-
RooArgSet pdfObs;
55-
nll.getObservables(&observables, pdfObs);
56-
if (weight)
57-
pdfObs.add(*weight);
58-
RooArgSet obsClones;
59-
pdfObs.snapshot(obsClones);
60-
for (RooAbsArg *arg : obsClones) {
61-
auto newName = std::string("_") + catName + "_" + arg->GetName();
62-
arg->setAttribute((std::string("ORIGNAME:") + arg->GetName()).c_str());
63-
arg->SetName(newName.c_str());
64-
}
65-
nll.recursiveRedirectServers(obsClones, false, true);
66-
newObservables.add(obsClones);
67-
static_cast<ROOT::Experimental::RooNLLVarNew &>(nll).setObservables(obsClones);
68-
nll.addOwnedComponents(std::move(obsClones));
56+
auto &nll = static_cast<RooNLLVarNew &>(nllTerms[iNLL]);
57+
newObservables.add(nll.prefixObservableAndWeightNames(std::string("_") + catName + "_"));
6958
++iNLL;
7059
}
7160

@@ -86,7 +75,7 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
8675
{
8776
std::unique_ptr<RooRealVar> weightVar;
8877

89-
std::unique_ptr<ROOT::Experimental::RooFitDriver> driver;
78+
std::unique_ptr<RooFitDriver> driver;
9079

9180
RooArgSet observables;
9281
pdf.getObservables(data.get(), observables);
@@ -104,16 +93,9 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
10493
}
10594

10695
if (data.isWeighted()) {
107-
std::string weightVarName = "_weight";
108-
if (auto *dataSet = dynamic_cast<RooDataSet const *>(&data)) {
109-
if (dataSet->weightVar())
110-
weightVarName = dataSet->weightVar()->GetName();
111-
}
112-
113-
// make a clone of the weight variable (or an initial instance, if it doesn't exist)
114-
// the clone will hold the weight value (or values as a batch) and will participate
115-
// in the computation graph of the RooFit driver.
116-
weightVar = std::make_unique<RooRealVar>(weightVarName.c_str(), "Weight(s) of events", data.weight());
96+
// RooRealVar for the weight value (or values as a batch) that will
97+
// participate in the computation graph of the RooFit driver.
98+
weightVar = std::make_unique<RooRealVar>(RooNLLVarNew::weightVarName, "Weight(s) of events", data.weight());
11799
}
118100

119101
// Deal with the IntegrateBins argument
@@ -135,8 +117,8 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
135117
nllTerms.addOwned(
136118
prepareSimultaneousModelForBatchMode(*simPdfClone, observables, weightVar.get(), isExtended, rangeName));
137119
} else {
138-
nllTerms.addOwned(std::make_unique<ROOT::Experimental::RooNLLVarNew>(
139-
"RooNLLVarNew", "RooNLLVarNew", finalPdf, observables, weightVar.get(), isExtended, rangeName));
120+
nllTerms.addOwned(std::make_unique<RooNLLVarNew>("RooNLLVarNew", "RooNLLVarNew", finalPdf, observables,
121+
weightVar.get(), isExtended, rangeName));
140122
}
141123
if (constraints) {
142124
nllTerms.addOwned(std::move(constraints));
@@ -151,10 +133,9 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
151133
RooArgSet parameters;
152134
pdf.getParameters(data.get(), parameters);
153135
nll->recursiveRedirectServers(parameters);
154-
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName,
155-
&simPdf->indexCat());
136+
driver = std::make_unique<RooFitDriver>(data, *nll, observables, batchMode, rangeName, &simPdf->indexCat());
156137
} else {
157-
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName);
138+
driver = std::make_unique<RooFitDriver>(data, *nll, observables, batchMode, rangeName);
158139
}
159140

160141
// Set the fitrange attribute so that RooPlot can automatically plot the fitting range by default
@@ -179,7 +160,7 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
179160
pdf.setStringAttribute("fitrange", fitrangeValue.c_str());
180161
}
181162

182-
auto driverWrapper = ROOT::Experimental::RooFitDriver::makeAbsRealWrapper(std::move(driver));
163+
auto driverWrapper = RooFitDriver::makeAbsRealWrapper(std::move(driver));
183164
driverWrapper->addOwnedComponents(std::move(nll));
184165
if (weightVar)
185166
driverWrapper->addOwnedComponents(std::move(weightVar));

roofit/roofitcore/src/RooAbsArg.cxx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,3 +2545,12 @@ std::string printValue(RooAbsArg *raa)
25452545
return s.str();
25462546
}
25472547
} // namespace cling
2548+
2549+
2550+
/// Disables or enables the usage of squared weights. Needs to be overridden in
2551+
/// the likelihood classes for which this is relevant.
2552+
void RooAbsArg::applyWeightSquared(bool flag) {
2553+
for(auto * server : servers()) {
2554+
server->applyWeightSquared(flag);
2555+
}
2556+
}

roofit/roofitcore/src/RooAbsPdf.cxx

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,31 +1309,17 @@ int RooAbsPdf::calcAsymptoticCorrectedCovariance(RooMinimizer &minimizer, RooAbs
13091309
/// matrix calculated here will be applied to it via
13101310
/// RooMinimizer::applyCovarianceMatrix().
13111311
/// \param[in] nll The NLL object that was used for the fit.
1312-
int RooAbsPdf::calcSumW2CorrectedCovariance(RooMinimizer &minimizer, RooAbsReal const &nll) const
1312+
int RooAbsPdf::calcSumW2CorrectedCovariance(RooMinimizer &minimizer, RooAbsReal &nll) const
13131313
{
1314-
1315-
// Make list of RooNLLVar components of FCN
1316-
std::vector<RooNLLVar *> nllComponents;
1317-
std::unique_ptr<RooArgSet> comps{nll.getComponents()};
1318-
for (auto const &arg : *comps) {
1319-
if (RooNLLVar *nllComp = dynamic_cast<RooNLLVar *>(arg)) {
1320-
nllComponents.push_back(nllComp);
1321-
}
1322-
}
1323-
13241314
// Calculated corrected errors for weighted likelihood fits
13251315
std::unique_ptr<RooFitResult> rw{minimizer.save()};
1326-
for (auto &comp : nllComponents) {
1327-
comp->applyWeightSquared(true);
1328-
}
1316+
nll.applyWeightSquared(true);
13291317
coutI(Fitting) << "RooAbsPdf::fitTo(" << this->GetName()
13301318
<< ") Calculating sum-of-weights-squared correction matrix for covariance matrix"
13311319
<< std::endl;
13321320
minimizer.hesse();
13331321
std::unique_ptr<RooFitResult> rw2{minimizer.save()};
1334-
for (auto &comp : nllComponents) {
1335-
comp->applyWeightSquared(false);
1336-
}
1322+
nll.applyWeightSquared(false);
13371323

13381324
// Apply correction matrix
13391325
const TMatrixDSym &matV = rw->covarianceMatrix();

roofit/roofitcore/src/RooFitDriver.cxx

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ and gets destroyed when the fitting ends.
4040
#include <RooRealVar.h>
4141
#include <RooSimultaneous.h>
4242
#include <RooBatchCompute/Initialisation.h>
43+
#include <RooBatchCompute/DataKey.h>
4344

4445
#include <ROOT/StringUtils.hxx>
4546

@@ -48,6 +49,19 @@ and gets destroyed when the fitting ends.
4849
#include <thread>
4950
#include <unordered_set>
5051

52+
namespace {
53+
54+
// Little wrapper to use a TNamed directly as a RooBatchCompute DataKey
55+
class NamePtrWrapper {
56+
public:
57+
NamePtrWrapper(TNamed const *namePtr) : _namePtr(namePtr) {}
58+
operator RooBatchCompute::DataKey() const { return RooBatchCompute::DataKey::create(_namePtr); }
59+
60+
private:
61+
TNamed const *_namePtr;
62+
};
63+
} // namespace
64+
5165
namespace ROOT {
5266
namespace Experimental {
5367

@@ -66,6 +80,7 @@ RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observab
6680
RooAbsCategory const *indexCat)
6781
: _nEvents{static_cast<size_t>(data.numEntries())}
6882
{
83+
auto &nameReg = RooNameReg::instance();
6984

7085
// fill the RunContext with the observable data and map the observables
7186
// by namePtr in order to replace their memory addresses later, with
@@ -80,7 +95,7 @@ RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observab
8095
// the data map
8196
for (auto const &item : data.getCategoryBatches(0, _nEvents)) {
8297

83-
const TNamed *namePtr = RooNameReg::instance().constPtr(item.first.c_str());
98+
const TNamed *namePtr = nameReg.constPtr(item.first.c_str());
8499
RooSpan<const RooAbsCategory::value_type> intSpan{item.second};
85100

86101
_buffers.emplace(_nEvents);
@@ -92,19 +107,16 @@ RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observab
92107
_dataSpans[namePtr] = RooSpan<const double>(buffer, _nEvents);
93108
}
94109

95-
// Check if there is a batch for weights and if it's already in the dataMap.
96-
// If not, we need to put the batch and give as a key a RooRealVar* that has
97-
// the same name as RooNLLVarNew's _weight proxy, so that it gets renamed like
98-
// every other observable.
99-
RooSpan<const double> weights = data.getWeightBatch(0, _nEvents);
100-
if (!weights.empty()) {
101-
std::string weightVarName = "_weight";
102-
if (auto *dataSet = dynamic_cast<RooDataSet const *>(&data)) {
103-
if (dataSet->weightVar())
104-
weightVarName = dataSet->weightVar()->GetName();
110+
// Add weights to the datamap. They should have the names expected by the
111+
// RooNLLVarNew. We also add the sumW2 weights here under a different name,
112+
// so we can apply the sumW2 correction by easily swapping the spans.
113+
{
114+
auto weight = data.getWeightBatch(0, _nEvents, /*sumW2=*/false);
115+
auto weightSumW2 = data.getWeightBatch(0, _nEvents, /*sumW2=*/true);
116+
if (!weight.empty()) {
117+
_dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarName)] = weight;
118+
_dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarNameSumW2)] = weightSumW2;
105119
}
106-
const TNamed *pTNamed = RooNameReg::instance().constPtr(weightVarName.c_str());
107-
_dataSpans[pTNamed] = weights;
108120
}
109121

110122
// Now we have to do the range selection
@@ -210,8 +222,8 @@ RooFitDriver::RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, Ro
210222
}
211223

212224
RooFitDriver::RooFitDriver(RooBatchCompute::RunContext const &data, const RooAbsReal &topNode, RooArgSet const &normSet)
213-
: _name{topNode.GetName()}, _title{topNode.GetTitle()},
214-
_batchMode{RooFit::BatchModeOption::Cpu}, _dataset{data}, _topNode{topNode}, _normSet{std::make_unique<RooArgSet>(normSet)}
225+
: _name{topNode.GetName()}, _title{topNode.GetTitle()}, _batchMode{RooFit::BatchModeOption::Cpu}, _dataset{data},
226+
_topNode{topNode}, _normSet{std::make_unique<RooArgSet>(normSet)}
215227
{
216228
init();
217229
}
@@ -263,12 +275,15 @@ void RooFitDriver::init()
263275
serverSet.add(*arg);
264276
}
265277

278+
for (auto const &span : _dataset.spans()) {
279+
_dataMapCPU[NamePtrWrapper(span.first)] = span.second;
280+
}
281+
266282
for (RooAbsArg *arg : serverSet) {
267283
_orderedNodes.push_back(arg);
268284
auto &argInfo = _nodeInfos[arg];
269-
if (_dataset.contains(arg)) {
270-
_dataMapCPU[arg] = _dataset.span(arg);
271-
argInfo.outputSize = _dataset.span(arg).size();
285+
if (_dataMapCPU.count(arg) > 0) {
286+
argInfo.outputSize = _dataMapCPU[arg].size();
272287
}
273288

274289
for (auto *client : arg->clients()) {

roofit/roofitcore/src/RooNLLVarNew.cxx

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,66 @@ void RooNLLVarNew::getParametersHook(const RooArgSet * /*nset*/, RooArgSet *para
221221
if (_weight)
222222
params->remove(**_weight, true, true);
223223
}
224+
225+
////////////////////////////////////////////////////////////////////////////////
226+
/// Replaces all observables and the weight variable of this NLL with clones
227+
/// that only differ by a prefix added to the names. Used for simultaneous fits.
228+
/// \return A RooArgSet with the new observable args.
229+
/// \param[in] prefix The prefix to add to the observables and weight names.
230+
RooArgSet RooNLLVarNew::prefixObservableAndWeightNames(std::string const &prefix)
231+
{
232+
RooArgSet obsSet{_observables};
233+
if (_weight)
234+
obsSet.add(**_weight);
235+
RooArgSet obsClones;
236+
obsSet.snapshot(obsClones);
237+
for (RooAbsArg *arg : obsClones) {
238+
arg->setAttribute((std::string("ORIGNAME:") + arg->GetName()).c_str());
239+
arg->SetName((prefix + arg->GetName()).c_str());
240+
}
241+
recursiveRedirectServers(obsClones, false, true);
242+
243+
RooArgSet newObservables{obsClones};
244+
if (_weight) {
245+
newObservables.remove(**_weight);
246+
}
247+
248+
setObservables(obsClones);
249+
addOwnedComponents(std::move(obsClones));
250+
251+
return newObservables;
252+
}
253+
254+
namespace {
255+
256+
inline bool endsWith(std::string const &value, std::string_view ending)
257+
{
258+
if (ending.size() > value.size())
259+
return false;
260+
return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
261+
}
262+
263+
} // namespace
264+
265+
////////////////////////////////////////////////////////////////////////////////
266+
/// Toggles the weight square correction by changing the name of the weight
267+
/// variable to get different weights from the RooFitDriver data map. The
268+
/// RooNLLVarNew::weightVarNameSumW2Suffix is either added or removed from the
269+
/// weight variable name.
270+
void RooNLLVarNew::applyWeightSquared(bool flag)
271+
{
272+
if (!_weight)
273+
return;
274+
auto &w = **_weight;
275+
276+
const std::string name = w.GetName();
277+
const std::string suffix = weightVarNameSumW2Suffix;
278+
279+
bool isAlreadyWeightSquared = endsWith(name, suffix);
280+
281+
if (isAlreadyWeightSquared && !flag) {
282+
w.SetName(name.substr(0, name.length() - suffix.length()).c_str());
283+
} else if (!isAlreadyWeightSquared && flag) {
284+
w.SetName((name + suffix).c_str());
285+
}
286+
}

0 commit comments

Comments
 (0)