Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RF] Make RooAbsReal::getValues use the batch mode via RooFitDriver #9986

Merged
merged 4 commits into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ParamHistFunc : public RooAbsReal {
Int_t addParamSet( const RooArgList& params );
static Int_t GetNumBins( const RooArgSet& vars );
double evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

private:
static NumBins getNumBinsPerDim(RooArgSet const& vars);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class PiecewiseInterpolation : public RooAbsReal {
std::vector<int> _interpCode;

Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

ClassDefOverride(PiecewiseInterpolation,4) // Sum of RooAbsReal objects
};
Expand Down
12 changes: 3 additions & 9 deletions roofit/histfactory/src/ParamHistFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -591,23 +591,19 @@ Double_t ParamHistFunc::evaluate() const
/// the associated parameters.
/// \param[in,out] evalData Input/output data for evaluating the ParamHistFunc.
/// \param[in] normSet Normalisation set passed on to objects that are serving values to us.
RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
void ParamHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap& dataMap) const {
std::vector<double> oldValues;
std::vector<RooSpan<const double>> data;
std::size_t batchSize = 0;

// Retrieve data for all variables
for (auto arg : _dataVars) {
const auto* var = static_cast<RooRealVar*>(arg);
oldValues.push_back(var->getVal());
data.push_back(var->getValues(evalData, normSet));
batchSize = std::max(batchSize, data.back().size());
data.push_back(dataMap[var]);
}

// Run computation for each entry in the dataset
RooSpan<double> output = evalData.makeBatch(this, batchSize);

for (std::size_t i = 0; i < batchSize; ++i) {
for (std::size_t i = 0; i < size; ++i) {
for (unsigned int j = 0; j < _dataVars.size(); ++j) {
assert(i < data[j].size());
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
Expand All @@ -624,8 +620,6 @@ RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalDat
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
var.setCachedValue(oldValues[j], /*notifyClients=*/false);
}

return output;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
21 changes: 10 additions & 11 deletions roofit/histfactory/src/PiecewiseInterpolation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,16 @@ Double_t PiecewiseInterpolation::evaluate() const
/// Interpolate between input distributions for all values of the observable in `evalData`.
/// \param[in,out] evalData Struct holding spans pointing to input data. The results of this function will be stored here.
/// \param[in] normSet Arguments to normalise over.
RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
auto nominal = _nominal->getValues(evalData, normSet);
auto sum = evalData.makeBatch(this, nominal.size());
std::copy(nominal.begin(), nominal.end(), sum.begin());
void PiecewiseInterpolation::computeBatch(cudaStream_t*, double* sum, size_t /*size*/, RooBatchCompute::DataMap& dataMap) const {
auto nominal = dataMap[&*_nominal];
for(unsigned int j=0; j < nominal.size(); ++j) {
sum[j] = nominal[j];
}

for (unsigned int i=0; i < _paramSet.size(); ++i) {
const double param = static_cast<RooAbsReal*>(_paramSet.at(i))->getVal();
auto low = static_cast<RooAbsReal*>(_lowSet.at(i) )->getValues(evalData, normSet);
auto high = static_cast<RooAbsReal*>(_highSet.at(i))->getValues(evalData, normSet);
auto low = dataMap[_lowSet.at(i)];
auto high = dataMap[_highSet.at(i)];
const int icode = _interpCode[i];

switch(icode) {
Expand Down Expand Up @@ -436,13 +437,11 @@ RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext
}

if (_positiveDefinite) {
for (double& val : sum) {
if (val < 0.)
val = 0.;
for(unsigned int j=0; j < nominal.size(); ++j) {
if (sum[j] < 0.)
sum[j] = 0.;
}
}

return sum;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooAbsArg.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
virtual bool canComputeBatchWithCuda() const { return false; }
virtual bool isReducerNode() const { return false; }

operator RooBatchCompute::DataKey() const { return RooBatchCompute::DataKey::create(this); }
operator RooBatchCompute::DataKey() const { return RooBatchCompute::DataKey::create(this->namePtr()); }

protected:
void graphVizAddConnections(std::set<std::pair<RooAbsArg*,RooAbsArg*> >&) ;
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooAbsReal.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class RooAbsReal : public RooAbsArg {
#endif
/// \copydoc getValBatch(std::size_t, std::size_t, const RooArgSet*)
virtual RooSpan<const double> getValues(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet = nullptr) const;
std::vector<double> getValues(RooAbsData& data, RooFit::BatchModeOption batchMode=RooFit::BatchModeOption::Cpu) const;
std::vector<double> getValues(RooAbsData const& data, RooFit::BatchModeOption batchMode=RooFit::BatchModeOption::Cpu) const;

Double_t getPropagatedError(const RooFitResult &fr, const RooArgSet &nset = RooArgSet()) const;

Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooBinWidthFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class RooBinWidthFunction : public RooAbsReal {
bool divideByBinWidth() const { return _divideByBinWidth; }
const RooHistFunc& histFunc() const { return (*_histFunc); }
double evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

private:
RooTemplateProxy<const RooHistFunc> _histFunc;
Expand Down
4 changes: 2 additions & 2 deletions roofit/roofitcore/inc/RooHistFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ class RooHistFunc : public RooAbsReal {


Int_t getBin() const;
std::vector<Int_t> getBins(RooBatchCompute::RunContext& evalData) const;
std::vector<Int_t> getBins(RooBatchCompute::DataMap& dataMap) const;

protected:

Bool_t importWorkspaceHook(RooWorkspace& ws) override ;
Bool_t areIdentical(const RooDataHist& dh1, const RooDataHist& dh2) ;

Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;
friend class RooAbsCachedReal ;

void ioStreamerPass2() override ;
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooProduct.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RooProduct : public RooAbsReal {

Double_t calculate(const RooArgList& partIntList) const;
Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t nEvents, RooBatchCompute::DataMap&) const override;

const char* makeFPName(const char *pfx,const RooArgSet& terms) const ;
ProdMap* groupProductTerms(const RooArgSet&) const;
Expand Down
1 change: 0 additions & 1 deletion roofit/roofitcore/inc/RooRealSumPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class RooRealSumPdf : public RooAbsPdf {
void setCacheAndTrackHints(RooArgSet&) override ;

protected:
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;

class CacheElem : public RooAbsCacheElement {
public:
Expand Down
31 changes: 23 additions & 8 deletions roofit/roofitcore/res/RooFitDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "RooAbsReal.h"
#include "RooBatchCompute.h"
#include "RooGlobalFunc.h"
#include "RunContext.h"

#include "RooFit/Detail/Buffers.h"

Expand Down Expand Up @@ -52,9 +53,9 @@ class RooFitDriver {
public:
class Dataset {
public:
Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName);

void splitByCategory(RooAbsCategory const &splitCategory);
Dataset(RooAbsData const &data, RooArgSet const &observables, std::string_view rangeName,
RooAbsCategory const *indexCat);
Dataset(RooBatchCompute::RunContext const &runContext);

std::size_t size() const { return _nEvents; }

Expand All @@ -73,14 +74,19 @@ class RooFitDriver {
std::map<const TNamed *, RooSpan<const double>> const &spans() const { return _dataSpans; }

private:
void splitByCategory(RooAbsCategory const &splitCategory);

std::map<const TNamed *, RooSpan<const double>> _dataSpans;
size_t _nEvents = 0;
std::stack<std::vector<double>> _buffers;
};

RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &observables,
RooArgSet const &normSet, RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooFitDriver(const RooAbsData &data, const RooAbsReal &topNode, RooArgSet const &normSet,
RooFit::BatchModeOption batchMode, std::string_view rangeName,
RooAbsCategory const *indexCat = nullptr);

RooFitDriver(const RooBatchCompute::RunContext &runContext, const RooAbsReal &topNode, RooArgSet const &normSet);

~RooFitDriver();
std::vector<double> getValues();
double getVal();
Expand Down Expand Up @@ -159,6 +165,8 @@ class RooFitDriver {
}
}

RooAbsArg *absArg = nullptr;

std::unique_ptr<Detail::AbsBuffer> buffer;

cudaEvent_t *event = nullptr;
Expand All @@ -185,6 +193,10 @@ class RooFitDriver {
RooBatchCompute::dispatchCUDA->deleteCudaStream(stream);
}
};

void init();

double getValHeterogeneous();
void updateMyClients(const RooAbsArg *node);
void updateMyServers(const RooAbsArg *node);
void handleIntegral(const RooAbsArg *node);
Expand Down Expand Up @@ -221,9 +233,12 @@ class RooFitDriver {
RooBatchCompute::DataMap _dataMapCPU;
RooBatchCompute::DataMap _dataMapCUDA;
const RooAbsReal &_topNode;
RooArgSet _normSet;
std::unordered_map<const RooAbsArg *, NodeInfo> _nodeInfos;
std::unordered_map<const RooAbsArg *, NodeInfo> _integralInfos;
std::unique_ptr<RooArgSet> _normSet;
std::map<RooAbsArg const *, NodeInfo> _nodeInfos;
std::map<RooAbsArg const *, NodeInfo> _integralInfos;

// the ordered computation graph
std::vector<RooAbsArg *> _orderedNodes;

// used for preserving resources
std::vector<double> _nonDerivedValues;
Expand Down
7 changes: 3 additions & 4 deletions roofit/roofitcore/src/BatchModeHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,10 @@ RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::uniqu
RooArgSet parameters;
pdf.getParameters(data.get(), parameters);
nll->recursiveRedirectServers(parameters);
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, observables, batchMode,
rangeName, &simPdf->indexCat());
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName,
&simPdf->indexCat());
} else {
driver =
std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, observables, batchMode, rangeName);
driver = std::make_unique<ROOT::Experimental::RooFitDriver>(data, *nll, observables, batchMode, rangeName);
}

// Set the fitrange attribute so that RooPlot can automatically plot the fitting range by default
Expand Down
57 changes: 9 additions & 48 deletions roofit/roofitcore/src/RooAbsPdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -377,54 +377,15 @@ Double_t RooAbsPdf::getValV(const RooArgSet* nset) const
/// \return RooSpan with probabilities. The memory of this span is owned by `evalData`.
/// \see RooAbsReal::getValues().
RooSpan<const double> RooAbsPdf::getValues(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
auto item = evalData.spans.find(this);
if (item != evalData.spans.end()) {
return item->second;
}

auto outputs = evaluateSpan(evalData, normSet);
assert(evalData.spans.count(this) > 0);

if (normSet != nullptr) {
if (normSet != _normSet || _norm == nullptr) {
syncNormalization(normSet);
}
// Evaluate denominator
// In most cases, the integral will be a scalar. But it can still happen
// that the integral is a vector, for example in a conditional fit where
// pdf is parametrized by another observable that is also a batch. That's
// why we have to use the batch interface also for the integral.
auto const& normVals = _norm->getValues(evalData);

if(normVals.size() > 1) {
for(std::size_t i = 0; i < outputs.size(); ++i) {
if (normVals[i] < 0. || (normVals[i] == 0. && outputs[i] != 0)) {
logEvalError(Form("p.d.f normalization integral is zero or negative."
"\n\tInt(%s) = %f", GetName(), normVals[i]));
}
if(normVals[i] != 1. && normVals[i] > 0.) {
outputs[i] /= normVals[i];
}
}

} else {
const double normVal = normVals[0];
if (normVal < 0.
|| (normVal == 0. && std::any_of(outputs.begin(), outputs.end(), [](double val){return val != 0;}))) {
logEvalError(Form("p.d.f normalization integral is zero or negative."
"\n\tInt(%s) = %f", GetName(), normVal));
}

if (normVal != 1. && normVal > 0.) {
const double invNorm = 1./normVal;
for (double& val : outputs) { //CHECK_VECTORISE
val *= invNorm;
}
}
}
}

return outputs;
// To avoid side effects of this function, the pointer to the last norm
// sets and integral objects are remembered and reset at the end of this
// function.
auto * prevNorm = _norm;
auto * prevNormSet = _normSet;
auto out = RooAbsReal::getValues(evalData, normSet);
_norm = prevNorm;
_normSet = prevNormSet;
return out;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
22 changes: 7 additions & 15 deletions roofit/roofitcore/src/RooAbsReal.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -315,27 +315,19 @@ RooSpan<const double> RooAbsReal::getValues(RooBatchCompute::RunContext& evalDat
return item->second;
}

if (normSet && normSet != _lastNSet) {
// TODO Implement better:
// The proxies, i.e. child nodes in the computation graph, sometimes need to know
// what to normalise over.
// Passing the normalisation as argument in all function calls is the proper way to do it.
// Some PDFs, however, might need to have the proxy normset set.
const_cast<RooAbsReal*>(this)->setProxyNormSet(normSet);
// TODO: This member only seems to be in use in RooFormulaVar. Try removing it (check with
// user community):
_lastNSet = (RooArgSet*) normSet;
}

auto results = evaluateSpan(evalData, normSet ? normSet : _lastNSet);
normSet = normSet ? normSet : _lastNSet;

ROOT::Experimental::RooFitDriver driver(evalData, *this, normSet ? *normSet : RooArgSet{});
auto& results = evalData.ownedMemory[this];
results = driver.getValues(); // the compiler should use the move assignment here
evalData.spans[this] = results;
return results;
}

////////////////////////////////////////////////////////////////////////////////

std::vector<double> RooAbsReal::getValues(RooAbsData& data, RooFit::BatchModeOption batchMode) const {
ROOT::Experimental::RooFitDriver driver(data, *this, *data.get(), *data.get(), batchMode, "");
std::vector<double> RooAbsReal::getValues(RooAbsData const& data, RooFit::BatchModeOption batchMode) const {
ROOT::Experimental::RooFitDriver driver(data, *this, *data.get(), batchMode, "");
return driver.getValues();
}

Expand Down
12 changes: 4 additions & 8 deletions roofit/roofitcore/src/RooBinWidthFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,18 @@ double RooBinWidthFunction::evaluate() const {
/// Compute bin index for all values of the observable(s) in `evalData`, and return their volumes or inverse volumes, depending
/// on the configuration chosen in the constructor.
/// If a bin is not valid, return a volume of 1.
RooSpan<double> RooBinWidthFunction::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const {
void RooBinWidthFunction::computeBatch(cudaStream_t*, double* output, size_t, RooBatchCompute::DataMap& dataMap) const {
const RooDataHist& dataHist = _histFunc->dataHist();
std::vector<Int_t> bins = _histFunc->getBins(evalData);
std::vector<Int_t> bins = _histFunc->getBins(dataMap);
auto volumes = dataHist.binVolumes(0, dataHist.numEntries());

auto results = evalData.makeBatch(this, bins.size());

if (_divideByBinWidth) {
for (std::size_t i=0; i < bins.size(); ++i) {
results[i] = bins[i] >= 0 ? 1./volumes[bins[i]] : 1.;
output[i] = bins[i] >= 0 ? 1./volumes[bins[i]] : 1.;
}
} else {
for (std::size_t i=0; i < bins.size(); ++i) {
results[i] = bins[i] >= 0 ? volumes[bins[i]] : 1.;
output[i] = bins[i] >= 0 ? volumes[bins[i]] : 1.;
}
}

return results;
}
Loading