Skip to content

Commit

Permalink
[RF] Add support for CompositeDataStore in new BatchMode
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Oct 21, 2022
1 parent b20f8b3 commit f53f6f2
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 79 deletions.
2 changes: 1 addition & 1 deletion roofit/roofitcore/res/RooFit/BatchModeDataHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace RooFit {
namespace BatchModeDataHelpers {

std::map<RooFit::Detail::DataKey, RooSpan<const double>>
getDataSpans(RooAbsData const &data, std::string_view rangeName, RooAbsCategory const *indexCat,
getDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
std::stack<std::vector<double>> &buffers, bool skipZeroWeights);

} // namespace BatchModeDataHelpers
Expand Down
2 changes: 2 additions & 0 deletions roofit/roofitcore/res/RooFitDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class RooFitDriver {

// RAII structures to reset state of computation graph after driver destruction
std::stack<RooHelpers::ChangeOperModeRAII> _changeOperModeRAIIs;

std::vector<std::unique_ptr<RooAbsData>> _splittedDataSets;
};

} // end namespace Experimental
Expand Down
99 changes: 23 additions & 76 deletions roofit/roofitcore/src/BatchModeDataHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,13 @@

#include <RooFit/BatchModeDataHelpers.h>

#include <RooAbsCategory.h>
#include <RooAbsData.h>
#include <RooNLLVarNew.h>

#include <ROOT/StringUtils.hxx>

#include <numeric>

namespace {

/// Split every span in a data map by the value of an index category.
///
/// For each entry of `dataSpans` other than the category itself, the values
/// are grouped by the category value of the corresponding event. Each group is
/// stored under a new key named `_<category state name>_<original name>`.
/// A span of size one is not split, but replicated once for each state of the
/// category. The span of the category itself is not part of the output.
///
/// \param[in,out] dataSpans The data map to split. It is replaced by the map
///                          of split spans. It must contain an entry for
///                          `category`.
/// \param[in] category The category whose values decide the splitting.
/// \param[in,out] buffers Storage for the split values. Its previous content
///                        is dropped when this function returns, and it is
///                        refilled with one vector for each output span.
void splitByCategory(std::map<RooFit::Detail::DataKey, RooSpan<const double>> &dataSpans,
                     RooAbsCategory const &category, std::stack<std::vector<double>> &buffers)
{
   // Move the existing buffers into a local stack, such that `buffers` starts
   // out empty and only receives the split vectors. The old buffers stay alive
   // until the end of this function.
   // NOTE(review): presumably some of the input spans point into the old
   // buffers, which is why they are kept alive during the split - confirm
   // against the caller.
   std::stack<std::vector<double>> oldBuffers;
   std::swap(buffers, oldBuffers);

   auto catVals = dataSpans.at(category.namePtr());

   auto &nameReg = RooNameReg::instance();

   std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataMapSplit;

   for (auto const &dataMapItem : dataSpans) {

      auto const &varNamePtr = dataMapItem.first;
      auto const &xVals = dataMapItem.second;

      // The category itself is not propagated to the split data map
      if (varNamePtr == category.namePtr())
         continue;

      std::map<RooAbsCategory::value_type, std::vector<double>> valuesMap;

      if (xVals.size() == 1) {
         // If the span is of size one, we will replicate it for each category
         // component instead of splitting it up by category value.
         for (auto const &catItem : category) {
            valuesMap[catItem.second].push_back(xVals[0]);
         }
      } else {
         for (std::size_t i = 0; i < xVals.size(); ++i) {
            valuesMap[catVals[i]].push_back(xVals[i]);
         }
      }

      // Iterate by non-const reference: moving from a const element would
      // silently fall back to copying each vector.
      for (auto &item : valuesMap) {
         RooAbsCategory::value_type index = item.first;
         auto variableName = std::string("_") + category.lookupName(index) + "_" + varNamePtr->GetName();
         auto variableNamePtr = nameReg.constPtr(variableName.c_str());

         buffers.emplace(std::move(item.second));
         auto const &values = buffers.top();
         dataMapSplit[variableNamePtr] = RooSpan<const double>(values.data(), values.size());
      }
   }

   dataSpans = std::move(dataMapSplit);
}

} // namespace

////////////////////////////////////////////////////////////////////////////////
/// Extract all content from a RooFit datasets as a map of spans.
/// Spans with the weights and squared weights will be also stored in the map,
Expand All @@ -81,11 +30,8 @@ void splitByCategory(std::map<RooFit::Detail::DataKey, RooSpan<const double>> &d
/// \param[in] data The input dataset.
/// \param[in] rangeName Select only entries from the data in a given range
/// (empty string for no range).
/// \param[in] indexCat If not `nullptr`, each span is split up by this category,
/// with the new names prefixed by the category component name
/// surrounded by underscores. For example, if you have a category
/// with `signal` and `control` samples, the span for a variable `x`
/// will be split in two spans `_signal_x` and `_control_x`.
/// \param[in] prefix A string prefix to use for all key names for the data
/// map.
/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
/// be used as memory for the data if the memory in the dataset
/// object can't be used directly (e.g. because you used the range
Expand All @@ -96,24 +42,31 @@ void splitByCategory(std::map<RooFit::Detail::DataKey, RooSpan<const double>> &d
/// original dataset anymore!
std::map<RooFit::Detail::DataKey, RooSpan<const double>>
RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_view rangeName,
RooAbsCategory const *indexCat, std::stack<std::vector<double>> &buffers,
std::string const &prefix, std::stack<std::vector<double>> &buffers,
bool skipZeroWeights)
{
std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataSpans; // output variable

auto &nameReg = RooNameReg::instance();

auto insert = [&](const char *key, RooSpan<const double> span) {
const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
dataSpans[namePtr] = span;
};

auto retrieve = [&](const char *key) {
const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
return dataSpans.at(namePtr);
};

std::size_t nEvents = static_cast<size_t>(data.numEntries());

// We also want to support empty datasets: in this case the
// RooFitDriver::Dataset is not filled with anything.
if (nEvents == 0)
if (nEvents == 0) {
return dataSpans;

if (!buffers.empty()) {
throw std::invalid_argument("The buffers container must be empty when passed to getDataSpans()!");
}

auto &nameReg = RooNameReg::instance();

auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);

Expand Down Expand Up @@ -154,15 +107,14 @@ RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_v
weightSumW2 = RooSpan<const double>(bufferSumW2.data(), nNonZeroWeight);
}
using namespace ROOT::Experimental;
dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarName)] = weight;
dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarNameSumW2)] = weightSumW2;
insert(RooNLLVarNew::weightVarName, weight);
insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
}

// Get the real-valued batches and cast them also to double branches to put in
// the data map
for (auto const &item : data.getBatches(0, nEvents)) {

const TNamed *namePtr = nameReg.constPtr(item.first->GetName());
RooSpan<const double> span{item.second};

buffers.emplace();
Expand All @@ -174,14 +126,13 @@ RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_v
buffer.push_back(span[i]);
}
}
dataSpans[namePtr] = RooSpan<const double>(buffer.data(), buffer.size());
insert(item.first->GetName(), {buffer.data(), buffer.size()});
}

// Get the category batches and cast them also to double branches to put in
// the data map
for (auto const &item : data.getCategoryBatches(0, nEvents)) {

const TNamed *namePtr = nameReg.constPtr(item.first->GetName());
RooSpan<const RooAbsCategory::value_type> intSpan{item.second};

buffers.emplace();
Expand All @@ -193,7 +144,7 @@ RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_v
buffer.push_back(static_cast<double>(intSpan[i]));
}
}
dataSpans[namePtr] = RooSpan<const double>(buffer.data(), buffer.size());
insert(item.first->GetName(), {buffer.data(), buffer.size()});
}

nEvents = nNonZeroWeight;
Expand All @@ -206,9 +157,9 @@ RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_v
std::vector<bool> isInSubRange(nEvents, true);
for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
// If the observable is not real-valued, it will not be considered for the range selection
if (!observable)
continue;
observable->inRange({dataSpans.at(observable->namePtr()).data(), nEvents}, range, isInSubRange);
if (observable) {
observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
}
}
for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
isInRange[i] = isInRange[i] || isInSubRange[i];
Expand Down Expand Up @@ -237,9 +188,5 @@ RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_v
}
}

if (indexCat) {
splitByCategory(dataSpans, *indexCat, buffers);
}

return dataSpans;
}
27 changes: 25 additions & 2 deletions roofit/roofitcore/src/RooFitDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ RooAbsPdf::fitTo() is called and gets destroyed when the fitting ends.

#include "NormalizationHelpers.h"

#include <TList.h>

#include <iomanip>
#include <numeric>
#include <thread>
Expand Down Expand Up @@ -195,10 +197,31 @@ void RooFitDriver::setData(RooAbsData const &data, std::string_view rangeName,
RooAbsCategory const *indexCatForSplitting, bool skipZeroWeights,
bool takeGlobalObservablesFromData)
{
std::vector<std::pair<std::string, RooAbsData const *>> datas;

if (indexCatForSplitting) {
std::unique_ptr<TList> splits{data.split(*indexCatForSplitting, true)};
for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
std::string prefix = std::string("_") + d->GetName() + "_";
datas.emplace_back(prefix, d);
_splittedDataSets.emplace_back(d);
}
} else {
datas.emplace_back("", &data);
}

DataSpansMap dataSpans;

std::stack<std::vector<double>>{}.swap(_vectorBuffers);
DataSpansMap dataSpans = RooFit::BatchModeDataHelpers::getDataSpans(data, rangeName, indexCatForSplitting,
_vectorBuffers, skipZeroWeights);

for (auto const &toAdd : datas) {
DataSpansMap spans = RooFit::BatchModeDataHelpers::getDataSpans(*toAdd.second, rangeName, toAdd.first,
_vectorBuffers, skipZeroWeights);
for (auto const &item : spans) {
dataSpans.insert(item);
}
}

if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
_vectorBuffers.emplace();
auto &buffer = _vectorBuffers.top();
Expand Down

0 comments on commit f53f6f2

Please sign in to comment.