Skip to content

Commit

Permalink
[RF] Make RooAbsReal::getValues use the batch mode via RooFitDriver
Browse files Browse the repository at this point in the history
The `RooAbsReal::getValues` has already been established as the entry
point for evaluating RooFit objects with the batch mode and it should
not be broken.

In 6.26, the `getValues` function was broken to fall back on the scalar
mode all the time, because the `evaluateSpan` functions it used got
replaced by `computeBatch`. In this commit, the desired behavior of
using the BatchMode is restored by using the RooFitDriver. To that end, a
new constructor has been added to the RooFitDriver that takes a
`RooBatchCompute::RunContext` directly.

The override of `getValues` in RooAbsPdf was also removed now, because
it's the job of the RooFitDriver to treat pdfs correctly.
  • Loading branch information
guitargeek committed Feb 25, 2022
1 parent c34dc26 commit ab474f7
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 115 deletions.
2 changes: 0 additions & 2 deletions roofit/roofitcore/inc/RooAbsPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ class RooAbsPdf : public RooAbsReal {
Double_t getValV(const RooArgSet* set=0) const override ;
virtual Double_t getLogVal(const RooArgSet* set=0) const ;

RooSpan<const double> getValues(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
using RooAbsReal::getValues;
RooSpan<const double> getLogValBatch(std::size_t begin, std::size_t batchSize,
const RooArgSet* normSet = nullptr) const;
RooSpan<const double> getLogProbabilities(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet = nullptr) const;
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooAbsReal.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class RooAbsReal : public RooAbsArg {
#endif
/// \copydoc getValBatch(std::size_t, std::size_t, const RooArgSet*)
virtual RooSpan<const double> getValues(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet = nullptr) const;
std::vector<double> getValues(RooAbsData& data, RooFit::BatchModeOption batchMode=RooFit::BatchModeOption::Cpu) const;
std::vector<double> getValues(RooAbsData const& data, RooFit::BatchModeOption batchMode=RooFit::BatchModeOption::Cpu) const;

Double_t getPropagatedError(const RooFitResult &fr, const RooArgSet &nset = RooArgSet()) const;

Expand Down
19 changes: 14 additions & 5 deletions roofit/roofitcore/res/RooFitDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "RooAbsReal.h"
#include "RooBatchCompute.h"
#include "RooGlobalFunc.h"
#include "RunContext.h"

#include "RooFit/Detail/Buffers.h"

Expand Down Expand Up @@ -52,9 +53,9 @@ class RooFitDriver {
public:
class Dataset {
public:
Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName);

void splitByCategory(RooAbsCategory const &splitCategory);
Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName,
RooAbsCategory const *indexCat);
Dataset(RooBatchCompute::RunContext const &runContext);

std::size_t size() const { return _nEvents; }

Expand All @@ -73,14 +74,19 @@ class RooFitDriver {
std::map<const TNamed *, RooSpan<const double>> const &spans() const { return _dataSpans; }

private:
void splitByCategory(RooAbsCategory const &splitCategory);

std::map<const TNamed *, RooSpan<const double>> _dataSpans;
size_t _nEvents = 0;
std::stack<std::vector<double>> _buffers;
};

RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &observables,
RooArgSet const &normSet, RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &normSet,
RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooAbsCategory const *indexCat = nullptr);

RooFitDriver(const RooBatchCompute::RunContext &runContext, const RooAbsReal &topNode, RooArgSet const &normSet);

~RooFitDriver();
std::vector<double> getValues();
double getVal();
Expand Down Expand Up @@ -187,6 +193,9 @@ class RooFitDriver {
RooBatchCompute::dispatchCUDA->deleteCudaStream(stream);
}
};

void init();

double getValHeterogeneous();
void updateMyClients(const RooAbsArg *node);
void updateMyServers(const RooAbsArg *node);
Expand Down
7 changes: 3 additions & 4 deletions roofit/roofitcore/src/BatchModeHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,10 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
RooArgSet parameters;
pdf.getParameters(data.get(), parameters);
nll->recursiveRedirectServers(parameters);
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, observables, batchMode,
rangeName, &simPdf->indexCat());
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName,
&simPdf->indexCat());
} else {
driver =
std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, observables, batchMode, rangeName);
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName);
}

// Set the fitrange attribute so that RooPlot can automatically plot the fitting range by default
Expand Down
65 changes: 0 additions & 65 deletions roofit/roofitcore/src/RooAbsPdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -362,71 +362,6 @@ Double_t RooAbsPdf::getValV(const RooArgSet* nset) const
}


////////////////////////////////////////////////////////////////////////////////
/// Compute batch of values for given input data, and normalise by integrating over
/// the observables in `normSet`. Store result in `evalData`, and return a span pointing to
/// it.
/// If `evalData.spans` already contains an entry for this object, that entry is returned
/// immediately and nothing is recomputed or renormalised.
/// This uses evaluateSpan() to perform an (unnormalised) computation of data points. This computation
/// is finalised by normalising the bare values, and by checking for computation errors.
/// Derived classes should override evaluateSpan() to reach maximal performance.
///
/// \param[in,out] evalData Object holding data that should be used in computations. Results are also stored here.
/// \param[in] normSet If not nullptr, normalise results by integrating over
/// the variables in this set. The normalisation is only computed once, and applied
/// to the full batch. If nullptr, the values from evaluateSpan() are returned unnormalised.
/// \return RooSpan with probabilities. The memory of this span is owned by `evalData`.
/// \see RooAbsReal::getValues().
RooSpan<const double> RooAbsPdf::getValues(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
// Reuse a previously stored result for this object, if there is one.
auto item = evalData.spans.find(this);
if (item != evalData.spans.end()) {
return item->second;
}

// Unnormalised values; evaluateSpan() is expected to register them in evalData.spans.
auto outputs = evaluateSpan(evalData, normSet);
assert(evalData.spans.count(this) > 0);

if (normSet != nullptr) {
// Refresh the normalisation object if the requested set differs from the stored
// one (pointer comparison) or if no normalisation object exists yet.
if (normSet != _normSet || _norm == nullptr) {
syncNormalization(normSet);
}
// Evaluate denominator
// In most cases, the integral will be a scalar. But it can still happen
// that the integral is a vector, for example in a conditional fit where
// pdf is parametrized by another observable that is also a batch. That's
// why we have to use the batch interface also for the integral.
auto const& normVals = _norm->getValues(evalData);

if(normVals.size() > 1) {
// Per-event normalisation: one integral value for each output entry.
// NOTE(review): normVals is indexed up to outputs.size(); this assumes
// normVals has at least as many entries as outputs -- confirm.
for(std::size_t i = 0; i < outputs.size(); ++i) {
// A negative integral, or a zero integral paired with a non-zero value, is reported as an error.
if (normVals[i] < 0. || (normVals[i] == 0. && outputs[i] != 0)) {
logEvalError(Form("p.d.f normalization integral is zero or negative."
"\n\tInt(%s) = %f", GetName(), normVals[i]));
}
// Only divide by strictly positive integrals; a value of exactly 1 needs no division.
if(normVals[i] != 1. && normVals[i] > 0.) {
outputs[i] /= normVals[i];
}
}

} else {
// Single normalisation value shared by the whole batch.
const double normVal = normVals[0];
// Same error condition as above; a zero integral is only an error if any output is non-zero.
if (normVal < 0.
|| (normVal == 0. && std::any_of(outputs.begin(), outputs.end(), [](double val){return val != 0;}))) {
logEvalError(Form("p.d.f normalization integral is zero or negative."
"\n\tInt(%s) = %f", GetName(), normVal));
}

// Scale by the reciprocal, computed once outside the loop; skipped for non-positive integrals and for 1.
if (normVal != 1. && normVal > 0.) {
const double invNorm = 1./normVal;
for (double& val : outputs) { //CHECK_VECTORISE
val *= invNorm;
}
}
}
}

return outputs;
}

////////////////////////////////////////////////////////////////////////////////
/// Analytical integral with normalization (see RooAbsReal::analyticalIntegralWN() for further information)
///
Expand Down
22 changes: 7 additions & 15 deletions roofit/roofitcore/src/RooAbsReal.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -315,27 +315,19 @@ RooSpan<const double> RooAbsReal::getValues(RooBatchCompute::RunContext& evalDat
return item->second;
}

if (normSet && normSet != _lastNSet) {
// TODO Implement better:
// The proxies, i.e. child nodes in the computation graph, sometimes need to know
// what to normalise over.
// Passing the normalisation as argument in all function calls is the proper way to do it.
// Some PDFs, however, might need to have the proxy normset set.
const_cast<RooAbsReal*>(this)->setProxyNormSet(normSet);
// TODO: This member only seems to be in use in RooFormulaVar. Try removing it (check with
// user community):
_lastNSet = (RooArgSet*) normSet;
}

auto results = evaluateSpan(evalData, normSet ? normSet : _lastNSet);
normSet = normSet ? normSet : _lastNSet;

ROOT::Experimental::RooFitDriver driver(evalData, *this, normSet ? *normSet : RooArgSet{});
auto& results = evalData.ownedMemory[this];
results = driver.getValues(); // the compiler should use the move assignment here
evalData.spans[this] = results;
return results;
}

////////////////////////////////////////////////////////////////////////////////

std::vector<double> RooAbsReal::getValues(RooAbsData& data, RooFit::BatchModeOption batchMode) const {
ROOT::Experimental::RooFitDriver driver(data, *this, *data.get(), *data.get(), batchMode, "");
std::vector<double> RooAbsReal::getValues(RooAbsData const& data, RooFit::BatchModeOption batchMode) const {
ROOT::Experimental::RooFitDriver driver(data, *this, *data.get(), batchMode, "");
return driver.getValues();
}

Expand Down
53 changes: 31 additions & 22 deletions roofit/roofitcore/src/RooFitDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ and gets destroyed when the fitting ends.
#include <RooRealVar.h>
#include <RooSimultaneous.h>
#include <RooBatchCompute/Initialisation.h>
#include <RunContext.h>

#include <ROOT/StringUtils.hxx>

Expand All @@ -54,7 +53,16 @@ namespace Experimental {
using namespace Detail;
using namespace std::chrono;

RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName)
/// Build a Dataset from the spans stored in a RooBatchCompute::RunContext.
/// Each span in `runContext.spans` is registered in `_dataSpans` under the
/// namePtr() of the object it is keyed by, and `_nEvents` is set to the size
/// of the largest span encountered.
/// \param[in] runContext Context whose `spans` entries are used as input data.
RooFitDriver::Dataset::Dataset(RooBatchCompute::RunContext const &runContext)
{
for (auto const &item : runContext.spans) {
// Key by namePtr() rather than by the object pointer stored in the RunContext.
_dataSpans[item.first->namePtr()] = item.second;
// Spans may differ in length; keep the maximum as the number of events.
_nEvents = std::max(_nEvents, item.second.size());
}
}

RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName,
RooAbsCategory const *indexCat)
: _nEvents{static_cast<size_t>(data.numEntries())}
{

Expand Down Expand Up @@ -131,6 +139,10 @@ RooFitDriver::Dataset::Dataset(RooAbsData const &data, RooArgSet const &observab
_dataSpans[item.first] = RooSpan<const double>{buffer, _nEvents};
}
}

if (indexCat) {
splitByCategory(*indexCat);
}
}

void RooFitDriver::Dataset::splitByCategory(RooAbsCategory const &category)
Expand Down Expand Up @@ -184,13 +196,26 @@ there's also some cuda-related initialization.
\param rangeName the range name
\param indexCat
**/
RooFitDriver::RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &observables,
RooArgSet const &normSet, RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooFitDriver::RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &normSet,
RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooAbsCategory const *indexCat)
: _name{topNode.GetName()}, _title{topNode.GetTitle()}, _parameters{*std::unique_ptr<RooArgSet>(
topNode.getParameters(*data.get(), true))},
_batchMode{batchMode}, _dataset{data, *std::unique_ptr<RooArgSet>(topNode.getObservables(data)), rangeName},
_batchMode{batchMode}, _dataset{data, *std::unique_ptr<RooArgSet>(topNode.getObservables(data)), rangeName,
indexCat},
_topNode{topNode}, _normSet{normSet}
{
init();
}

/// Construct a RooFitDriver that evaluates `topNode` on data taken from a
/// RooBatchCompute::RunContext instead of a RooAbsData.
/// The batch mode is fixed to RooFit::BatchModeOption::Cpu, and the remaining
/// setup is delegated to init().
/// \param[in] data RunContext used to build the internal Dataset.
/// \param[in] topNode Top node of the computation graph to evaluate.
/// \param[in] normSet Normalisation set stored in `_normSet`.
/// NOTE(review): unlike the RooAbsData constructor, `_parameters` is not set in
/// this initializer list -- confirm that leaving it default-initialised is intended.
RooFitDriver::RooFitDriver(RooBatchCompute::RunContext const &data, const RooAbsReal &topNode, RooArgSet const &normSet)
: _name{topNode.GetName()}, _title{topNode.GetTitle()},
_batchMode{RooFit::BatchModeOption::Cpu}, _dataset{data}, _topNode{topNode}, _normSet{normSet}
{
init();
}

void RooFitDriver::init()
{
// Initialize RooBatchCompute
RooBatchCompute::init();
Expand Down Expand Up @@ -220,19 +245,7 @@ RooFitDriver::RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, Ro
// treeNodeServerList() is recursive and adds the top node before the children,
// so reversing the list gives us a topological ordering of the graph.
RooArgList serverList;
_topNode.treeNodeServerList(&serverList, nullptr, true, true, true);

// The treeNodeServerList recursion stops at fundamental RooAbsArgs, such as
// RooRealVar. But there can also be servers to fundamental types if they are
// not value servers but define for example the range. That's why we
// explicitly add the servers of the observables here.
for (auto *obs : observables) {
for (auto *server : obs->servers()) {
RooArgList tmp;
server->treeNodeServerList(&tmp, nullptr, true, true, true);
serverList.add(tmp);
}
}
_topNode.treeNodeServerList(&serverList, nullptr, true, true, false, true);

// To remove duplicates via the RooArgSet deduplication, we have to fill the
// set in reverse order because that's the dependency ordering of the graph.
Expand All @@ -249,10 +262,6 @@ RooFitDriver::RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, Ro
serverSet.add(*arg);
}

if (indexCat) {
_dataset.splitByCategory(*indexCat);
}

for (RooAbsArg *arg : serverSet) {
_orderedNodes.push_back(arg);
auto &argInfo = _nodeInfos[arg];
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/test/testRooFitDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ TEST(testRooFitDriver, SimpleLikelihoodFit)

// ...and now the new way with RooFitDriver
ROOT::Experimental::RooNLLVarNew nll("nll", "nll", model, *data->get(), nullptr, false, "");
ROOT::Experimental::RooFitDriver driver(*data, nll, x, x, RooFit::BatchModeOption::Cpu, "");
ROOT::Experimental::RooFitDriver driver(*data, nll, x, RooFit::BatchModeOption::Cpu, "");
auto resultBatchNew = doFit(*driver.makeAbsRealWrapper());
if (verbose)
std::cout << "- batch mode fit took " << resultBatchNew.elapsedTime << " ms" << std::endl;
Expand Down

0 comments on commit ab474f7

Please sign in to comment.