Skip to content

Commit

Permalink
[RF] Migrate RooProduct and RooRealSumPdf to computeBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Feb 26, 2022
1 parent 1094ad7 commit 5476f11
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 52 deletions.
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooProduct.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RooProduct : public RooAbsReal {

Double_t calculate(const RooArgList& partIntList) const;
Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t nEvents, RooBatchCompute::DataMap&) const override;

const char* makeFPName(const char *pfx,const RooArgSet& terms) const ;
ProdMap* groupProductTerms(const RooArgSet&) const;
Expand Down
1 change: 0 additions & 1 deletion roofit/roofitcore/inc/RooRealSumPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class RooRealSumPdf : public RooAbsPdf {
void setCacheAndTrackHints(RooArgSet&) override ;

protected:
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;

class CacheElem : public RooAbsCacheElement {
public:
Expand Down
34 changes: 10 additions & 24 deletions roofit/roofitcore/src/RooProduct.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -384,46 +384,32 @@ Double_t RooProduct::evaluate() const
}


////////////////////////////////////////////////////////////////////////////////
/// Evaluate product of input functions for all points found in `evalData`.
RooSpan<double> RooProduct::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
RooSpan<double> prod;

assert(_compRSet.nset() == normSet);
void RooProduct::computeBatch(cudaStream_t* /*stream*/, double* output, size_t nEvents, RooBatchCompute::DataMap& dataMap) const
{
for (unsigned int i = 0; i < nEvents; ++i) {
output[i] = 1.;
}

for (const auto item : _compRSet) {
auto rcomp = static_cast<const RooAbsReal*>(item);
auto componentValues = rcomp->getValues(evalData, normSet);

if (prod.empty()) {
prod = evalData.makeBatch(this, componentValues.size());
for (auto& val : prod) val = 1.;
} else if (prod.size() == 1 && componentValues.size() > 1) {
const double val = prod[0];
prod = evalData.makeBatch(this, componentValues.size());
std::fill(prod.begin(), prod.end(), val);
}
assert(prod.size() == componentValues.size() || componentValues.size() == 1);
auto componentValues = dataMap[rcomp];

for (unsigned int i = 0; i < prod.size(); ++i) {
prod[i] *= componentValues.size() == 1 ? componentValues[0] : componentValues[i];
for (unsigned int i = 0; i < nEvents; ++i) {
output[i] *= componentValues.size() == 1 ? componentValues[0] : componentValues[i];
}
}

for (const auto item : _compCSet) {
auto ccomp = static_cast<const RooAbsCategory*>(item);
const int catIndex = ccomp->getCurrentIndex();

for (unsigned int i = 0; i < prod.size(); ++i) {
prod[i] *= catIndex;
for (unsigned int i = 0; i < nEvents; ++i) {
output[i] *= catIndex;
}
}

return prod;
}



////////////////////////////////////////////////////////////////////////////////
/// Forward the plot sampling hint from the p.d.f. that defines the observable obs

Expand Down
37 changes: 11 additions & 26 deletions roofit/roofitcore/src/RooRealSumPdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ Double_t RooRealSumPdf::evaluate() const
}


void RooRealSumPdf::computeBatch(cudaStream_t* stream, double* output, size_t nEvents, RooBatchCompute::DataMap& dataMap) const {
void RooRealSumPdf::computeBatch(cudaStream_t* /*stream*/, double* output, size_t nEvents, RooBatchCompute::DataMap& dataMap) const {

// To evaluate this RooRealSumPdf, we have to undo the normalization of the
// pdf servers by convention. TODO: find a less hacky solution for this,
Expand All @@ -283,34 +283,21 @@ void RooRealSumPdf::computeBatch(cudaStream_t* stream, double* output, size_t nE
}
}

RooAbsPdf::computeBatch(stream, output, nEvents, dataMapCopy);
}


////////////////////////////////////////////////////////////////////////////////
/// Calculate the value for all values of the observable in `evalData`.
RooSpan<double> RooRealSumPdf::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const {
// Do running sum of coef/func pairs, calculate lastCoef.
RooSpan<double> values;
for (unsigned int j = 0; j < nEvents; ++j) {
output[j] = 0.0;
}

double sumCoeff = 0.;
for (unsigned int i = 0; i < _funcList.size(); ++i) {
const auto func = static_cast<RooAbsReal*>(&_funcList[i]);
const auto coef = static_cast<RooAbsReal*>(i < _coefList.size() ? &_coefList[i] : nullptr);
const double coefVal = coef != nullptr ? coef->getVal() : (1. - sumCoeff);

if (func->isSelectedComp()) {
auto funcValues = func->getValues(evalData, nullptr); // No normSet here, because we are summing functions!
if (values.empty() || (values.size() == 1 && funcValues.size() > 1)) {
const double init = values.empty() ? 0. : values[0];
values = evalData.makeBatch(this, funcValues.size());
for (unsigned int j = 0; j < values.size(); ++j) {
values[j] = init + funcValues[j] * coefVal;
}
} else {
assert(values.size() == funcValues.size());
for (unsigned int j = 0; j < values.size(); ++j) {
values[j] += funcValues[j] * coefVal;
}
auto funcValues = dataMapCopy[func];
for (unsigned int j = 0; j < nEvents; ++j) {
output[j] += funcValues[j] * coefVal;
}
}

Expand All @@ -323,20 +310,18 @@ RooSpan<double> RooRealSumPdf::evaluateSpan(RooBatchCompute::RunContext& evalDat
_haveWarned = true;
}
// Signal that we are in an undefined region by handing back one NaN.
values[0] = RooNaNPacker::packFloatIntoNaN(100.f * (coefVal < 0. ? -coefVal : coefVal - 1.));
output[0] = RooNaNPacker::packFloatIntoNaN(100.f * (coefVal < 0. ? -coefVal : coefVal - 1.));
}

sumCoeff += coefVal;
}

// Introduce floor if so requested
if (_doFloor || _doFloorGlobal) {
for (unsigned int j = 0; j < values.size(); ++j) {
values[j] += std::max(0., values[j]);
for (unsigned int j = 0; j < nEvents; ++j) {
output[j] += std::max(0., output[j]);
}
}

return values;
}


Expand Down

0 comments on commit 5476f11

Please sign in to comment.