Skip to content

Commit

Permalink
[RF] Reduce code duplication in RooAbsData::split()
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Oct 21, 2022
1 parent e4e870a commit b20f8b3
Showing 1 changed file with 107 additions and 152 deletions.
259 changes: 107 additions & 152 deletions roofit/roofitcore/src/RooAbsData.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,95 @@ TH1 *RooAbsData::fillHistogram(TH1 *hist, const RooArgList &plotVars, const char
return hist;
}


namespace {

/// State shared by the RooAbsData::split() overloads, filled in by initSplit().
///
/// `ownedSet` is intentionally the first member: `subsetVars` can hold a
/// non-owning reference to the helper weight variable that `ownedSet` owns, so
/// `ownedSet` has to be destroyed after `subsetVars`.
struct SplittingSetup {
   /// Owns the objects created during setup: the deep-cloned splitting
   /// category (derived categories only) and the helper weight variable.
   RooArgSet ownedSet;
   /// Category that is evaluated to select the subset of an event.
   /// Remains null when the setup failed.
   RooAbsCategory *cloneCat{nullptr};
   /// Variables that the split datasets are built from.
   RooArgSet subsetVars;
   /// True if an explicit "weight" variable was added to `subsetVars`.
   bool addWeightVar{false};
};

/// Prepare the state needed to split `data` according to the states of `splitCat`.
///
/// \param[in] data      Dataset that is going to be split.
/// \param[in] splitCat  Category whose states define the subsets.
/// \return The filled SplittingSetup. If anything goes wrong, an error is
///         logged and the returned setup has `cloneCat == nullptr`; callers
///         must check for this before using the setup.
SplittingSetup initSplit(RooAbsData const &data, RooAbsCategory const &splitCat)
{
   SplittingSetup setup;

   // Sanity check
   if (!splitCat.dependsOn(*data.get())) {
      oocoutE(&data, InputArguments) << "RooTreeData::split(" << data.GetName() << ") ERROR category "
                                     << splitCat.GetName() << " doesn't depend on any variable in this dataset"
                                     << std::endl;
      return setup;
   }

   const bool isDerived = splitCat.isDerived();

   if (isDerived) {
      // Derived category: work on a deep clone that is attached to the dataset.
      // snapshot() reports failure by returning true. Also guard against the
      // clone not being retrievable, so that we never dereference a null pointer.
      const bool cloneFailed = RooArgSet(splitCat).snapshot(setup.ownedSet, true);
      auto *clone = cloneFailed ? nullptr : dynamic_cast<RooAbsCategory *>(setup.ownedSet.find(splitCat.GetName()));
      if (!clone) {
         oocoutE(&data, InputArguments) << "RooTreeData::split(" << data.GetName()
                                        << ") Couldn't deep-clone splitting category, abort." << std::endl;
         return setup;
      }
      clone->attachDataSet(data);
      setup.cloneCat = clone;
   } else {
      // Fundamental category: use the instance that lives in the dataset itself
      setup.cloneCat = dynamic_cast<RooAbsCategory *>(data.get()->find(splitCat.GetName()));
      if (!setup.cloneCat) {
         oocoutE(&data, InputArguments) << "RooTreeData::split(" << data.GetName() << ") ERROR category "
                                        << splitCat.GetName() << " is fundamental and does not appear in this dataset"
                                        << std::endl;
         return setup;
      }
   }

   // Construct set of variables to be included in split sets = full set - split category
   setup.subsetVars.add(*data.get());
   if (isDerived) {
      // For a derived category, remove the variables it depends on
      std::unique_ptr<RooArgSet> vars{splitCat.getVariables()};
      setup.subsetVars.remove(*vars, true, true);
   } else {
      setup.subsetVars.remove(splitCat, true, true);
   }

   // Add weight variable explicitly if dataset has weights, but no top-level weight
   // variable exists (can happen with composite datastores)
   if (data.isWeighted() && !data.IsA()->InheritsFrom(RooDataHist::Class())) {
      auto newweight = std::make_unique<RooRealVar>("weight", "weight", -1e9, 1e9);
      // subsetVars only references the variable; ownership goes to ownedSet
      setup.subsetVars.add(*newweight);
      setup.addWeightVar = true;
      setup.ownedSet.addOwned(std::move(newweight));
   }

   return setup;
}

/// Distribute the events of `data` over one dataset per state of `cloneCat`.
///
/// \param[in] data                 Dataset whose events are distributed.
/// \param[in] cloneCat             Category whose current label, read after loading
///                                 each event, names the subset the event is added to.
/// \param[in] createEmptyDataSets  If true, a subset is created up front for every
///                                 state of `cloneCat`, even if no event ends up in it.
/// \param[in] createEmptyData      Factory that creates the empty subset for a given
///                                 state label. Subsets are looked up in the list by
///                                 that label.
/// \return A newly allocated TList with the subsets.
///         NOTE(review): the list is returned as a raw pointer, so ownership
///         presumably passes to the caller; confirm against the callers.
TList *splitImpl(RooAbsData const &data, const RooAbsCategory &cloneCat, bool createEmptyDataSets,
                 const std::function<RooAbsData *(const char *label)> &createEmptyData)
{
   auto dsetList = new TList;

   // If createEmptyDataSets is true, prepopulate with empty sets corresponding to all states
   if (createEmptyDataSets) {
      for (const auto &nameIdx : cloneCat) {
         // TList::Add() takes a TObject*, so the RooAbsData pointer is passed
         // as is. No cast to the unrelated RooAbsArg type is needed.
         dsetList->Add(createEmptyData(nameIdx.first.c_str()));
      }
   }

   // Loop over dataset and copy event to matching subset
   const bool propWeightSquared = data.isWeighted();
   const Int_t nEntries = data.numEntries();
   for (Int_t i = 0; i < nEntries; ++i) {
      // Loading the event is what updates the current label of the category
      const RooArgSet *row = data.get(i);
      const char *label = cloneCat.getCurrentLabel();

      auto *subset = static_cast<RooAbsData *>(dsetList->FindObject(label));
      if (!subset) {
         // First event of this state: create the subset on demand
         subset = createEmptyData(label);
         dsetList->Add(subset);
      }
      subset->add(*row, data.weight(), propWeightSquared ? data.weightSquared() : 0.0);
   }

   return dsetList;
}

} // namespace


////////////////////////////////////////////////////////////////////////////////
/// Split dataset into subsets based on states of given splitCat in this dataset.
/// A TList of RooDataSets is returned in which each RooDataSet is named
Expand All @@ -1563,81 +1652,16 @@ TH1 *RooAbsData::fillHistogram(TH1 *hist, const RooArgList &plotVars, const char

TList* RooAbsData::split(const RooAbsCategory& splitCat, bool createEmptyDataSets) const
{
// Sanity check
if (!splitCat.dependsOn(*get())) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") ERROR category " << splitCat.GetName()
<< " doesn't depend on any variable in this dataset" << endl ;
return nullptr;
}

// Clone splitting category and attach to self
RooAbsCategory* cloneCat =0;
std::unique_ptr<RooArgSet> cloneSet;
if (splitCat.isDerived()) {
cloneSet.reset(static_cast<RooArgSet*>(RooArgSet(splitCat).snapshot(true)));
if (!cloneSet) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") Couldn't deep-clone splitting category, abort." << endl ;
return nullptr;
}
cloneCat = (RooAbsCategory*) cloneSet->find(splitCat.GetName()) ;
cloneCat->attachDataSet(*this) ;
} else {
cloneCat = dynamic_cast<RooAbsCategory*>(get()->find(splitCat.GetName())) ;
if (!cloneCat) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") ERROR category " << splitCat.GetName()
<< " is fundamental and does not appear in this dataset" << endl ;
return nullptr;
}
}
SplittingSetup setup = initSplit(*this, splitCat);

// Split a dataset in a series of subsets, each corresponding
// to a state of splitCat
TList* dsetList = new TList ;
// Something went wrong
if(!setup.cloneCat) return nullptr;

// Construct set of variables to be included in split sets = full set - split category
RooArgSet subsetVars(*get()) ;
if (splitCat.isDerived()) {
std::unique_ptr<RooArgSet> vars{splitCat.getVariables()};
subsetVars.remove(*vars,true,true) ;
} else {
subsetVars.remove(splitCat,true,true) ;
}

// Add weight variable explicitly if dataset has weights, but no top-level weight
// variable exists (can happen with composite datastores)
bool addWV(false) ;
RooRealVar newweight("weight","weight",-1e9,1e9) ;
if (isWeighted() && !IsA()->InheritsFrom(RooDataHist::Class())) {
subsetVars.add(newweight) ;
addWV = true ;
}

// If createEmptyDataSets is true, prepopulate with empty sets corresponding to all states
if (createEmptyDataSets) {
for (const auto& nameIdx : *cloneCat) {
RooAbsData* subset = emptyClone(nameIdx.first.c_str(), nameIdx.first.c_str(), &subsetVars,(addWV?"weight":0)) ;
dsetList->Add((RooAbsArg*)subset) ;
}
}


// Loop over dataset and copy event to matching subset
const bool propWeightSquared = isWeighted();
for (Int_t i = 0; i < numEntries(); ++i) {
const RooArgSet* row = get(i);
RooAbsData* subset = (RooAbsData*) dsetList->FindObject(cloneCat->getCurrentLabel());
if (!subset) {
subset = emptyClone(cloneCat->getCurrentLabel(),cloneCat->getCurrentLabel(),&subsetVars,(addWV?"weight":0));
dsetList->Add((RooAbsArg*)subset);
}
if (!propWeightSquared) {
subset->add(*row, weight());
} else {
subset->add(*row, weight(), weightSquared());
}
}
auto createEmptyData = [&](const char * label) -> RooAbsData* {
return emptyClone(label, label, &setup.subsetVars, setup.addWeightVar ? "weight" : nullptr);
};

return dsetList;
return splitImpl(*this, *setup.cloneCat, createEmptyDataSets, createEmptyData);
}

////////////////////////////////////////////////////////////////////////////////
Expand All @@ -1653,54 +1677,10 @@ TList* RooAbsData::split(const RooSimultaneous& simpdf, bool createEmptyDataSets
{
auto& splitCat = const_cast<RooAbsCategoryLValue&>(simpdf.indexCat());

// Sanity check
if (!splitCat.dependsOn(*get())) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") ERROR category " << splitCat.GetName()
<< " doesn't depend on any variable in this dataset" << endl ;
return nullptr;
}

// Clone splitting category and attach to self
RooAbsCategory* cloneCat =0;
std::unique_ptr<RooArgSet> cloneSet;
if (splitCat.isDerived()) {
cloneSet.reset(static_cast<RooArgSet*>(RooArgSet(splitCat).snapshot(true)));
if (!cloneSet) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") Couldn't deep-clone splitting category, abort." << endl ;
return nullptr;
}
cloneCat = (RooAbsCategory*) cloneSet->find(splitCat.GetName()) ;
cloneCat->attachDataSet(*this) ;
} else {
cloneCat = dynamic_cast<RooAbsCategory*>(get()->find(splitCat.GetName())) ;
if (!cloneCat) {
coutE(InputArguments) << "RooTreeData::split(" << GetName() << ") ERROR category " << splitCat.GetName()
<< " is fundamental and does not appear in this dataset" << endl ;
return nullptr;
}
}

// Split a dataset in a series of subsets, each corresponding
// to a state of splitCat
TList* dsetList = new TList ;
SplittingSetup setup = initSplit(*this, splitCat);

// Construct set of variables to be included in split sets = full set - split category
RooArgSet subsetVars(*get()) ;
if (splitCat.isDerived()) {
std::unique_ptr<RooArgSet> vars{splitCat.getVariables()};
subsetVars.remove(*vars,true,true) ;
} else {
subsetVars.remove(splitCat,true,true) ;
}

// Add weight variable explicitly if dataset has weights, but no top-level weight
// variable exists (can happen with composite datastores)
bool addWV(false) ;
RooRealVar newweight("weight","weight",-1e9,1e9) ;
if (isWeighted() && !IsA()->InheritsFrom(RooDataHist::Class())) {
subsetVars.add(newweight) ;
addWV = true ;
}
// Something went wrong
if(!setup.cloneCat) return nullptr;

// Get the observables for a given pdf in the RooSimultaneous, or an empty
// RooArgSet if no pdf is set
Expand All @@ -1717,41 +1697,16 @@ TList* RooAbsData::split(const RooSimultaneous& simpdf, bool createEmptyDataSets
for( const auto& catPair : splitCat) {
allObservables.add(getPdfObservables(catPair.first.c_str()));
}
subsetVars.remove(allObservables, true, true);

setup.subsetVars.remove(allObservables, true, true);

// If createEmptyDataSets is true, prepopulate with empty sets corresponding to all states
if (createEmptyDataSets) {
for (const auto& nameIdx : *cloneCat) {
// Add in the subset only the observables corresponding to this category
RooArgSet subsetVarsCat(subsetVars);
subsetVarsCat.add(getPdfObservables(nameIdx.first.c_str()));
RooAbsData* subset = emptyClone(nameIdx.first.c_str(), nameIdx.first.c_str(), &subsetVarsCat,(addWV?"weight":0)) ;
dsetList->Add(subset) ;
}
}


// Loop over dataset and copy event to matching subset
const bool propWeightSquared = isWeighted();
for (Int_t i = 0; i < numEntries(); ++i) {
const RooArgSet* row = get(i);
RooAbsData* subset = (RooAbsData*) dsetList->FindObject(cloneCat->getCurrentLabel());
if (!subset) {
// Add in the subset only the observables corresponding to this category
RooArgSet subsetVarsCat(subsetVars);
subsetVarsCat.add(getPdfObservables(cloneCat->getCurrentLabel()));
subset = emptyClone(cloneCat->getCurrentLabel(),cloneCat->getCurrentLabel(),&subsetVarsCat,(addWV?"weight":0));
dsetList->Add(subset);
}
if (!propWeightSquared) {
subset->add(*row, weight());
} else {
subset->add(*row, weight(), weightSquared());
}
}
auto createEmptyData = [&](const char * label) -> RooAbsData* {
// Add in the subset only the observables corresponding to this category
RooArgSet subsetVarsCat(setup.subsetVars);
subsetVarsCat.add(getPdfObservables(label));
return this->emptyClone(label, label, &subsetVarsCat, setup.addWeightVar ? "weight" : nullptr);
};

return dsetList;
return splitImpl(*this, *setup.cloneCat, createEmptyDataSets, createEmptyData);
}

////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit b20f8b3

Please sign in to comment.