Skip to content

Commit

Permalink
[RF] Migrate more RooFit classes from evaluateSpan to computeBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Mar 4, 2022
1 parent 3db58d2 commit c45c7b6
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ParamHistFunc : public RooAbsReal {
Int_t addParamSet( const RooArgList& params );
static Int_t GetNumBins( const RooArgSet& vars );
double evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

private:
static NumBins getNumBinsPerDim(RooArgSet const& vars);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class PiecewiseInterpolation : public RooAbsReal {
std::vector<int> _interpCode;

Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

ClassDefOverride(PiecewiseInterpolation,4) // Sum of RooAbsReal objects
};
Expand Down
17 changes: 3 additions & 14 deletions roofit/histfactory/src/ParamHistFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -586,28 +586,19 @@ Double_t ParamHistFunc::evaluate() const
}


////////////////////////////////////////////////////////////////////////////////
/// Find all bins corresponding to the values of the observables in `evalData`, and evaluate
/// the associated parameters.
/// \param[in,out] evalData Input/output data for evaluating the ParamHistFunc.
/// \param[in] normSet Normalisation set passed on to objects that are serving values to us.
RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
void ParamHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap& dataMap) const {
std::vector<double> oldValues;
std::vector<RooSpan<const double>> data;
std::size_t batchSize = 0;

// Retrieve data for all variables
for (auto arg : _dataVars) {
const auto* var = static_cast<RooRealVar*>(arg);
oldValues.push_back(var->getVal());
data.push_back(var->getValues(evalData, normSet));
batchSize = std::max(batchSize, data.back().size());
data.push_back(dataMap[var]);
}

// Run computation for each entry in the dataset
RooSpan<double> output = evalData.makeBatch(this, batchSize);

for (std::size_t i = 0; i < batchSize; ++i) {
for (std::size_t i = 0; i < size; ++i) {
for (unsigned int j = 0; j < _dataVars.size(); ++j) {
assert(i < data[j].size());
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
Expand All @@ -624,8 +615,6 @@ RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalDat
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
var.setCachedValue(oldValues[j], /*notifyClients=*/false);
}

return output;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
21 changes: 10 additions & 11 deletions roofit/histfactory/src/PiecewiseInterpolation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,16 @@ Double_t PiecewiseInterpolation::evaluate() const
/// Interpolate between input distributions for all values of the observable in `evalData`.
/// \param[in,out] evalData Struct holding spans pointing to input data. The results of this function will be stored here.
/// \param[in] normSet Arguments to normalise over.
RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
auto nominal = _nominal->getValues(evalData, normSet);
auto sum = evalData.makeBatch(this, nominal.size());
std::copy(nominal.begin(), nominal.end(), sum.begin());
void PiecewiseInterpolation::computeBatch(cudaStream_t*, double* sum, size_t /*size*/, RooBatchCompute::DataMap& dataMap) const {
auto nominal = dataMap[&*_nominal];
for(unsigned int j=0; j < nominal.size(); ++j) {
sum[j] = nominal[j];
}

for (unsigned int i=0; i < _paramSet.size(); ++i) {
const double param = static_cast<RooAbsReal*>(_paramSet.at(i))->getVal();
auto low = static_cast<RooAbsReal*>(_lowSet.at(i) )->getValues(evalData, normSet);
auto high = static_cast<RooAbsReal*>(_highSet.at(i))->getValues(evalData, normSet);
auto low = dataMap[_lowSet.at(i)];
auto high = dataMap[_highSet.at(i)];
const int icode = _interpCode[i];

switch(icode) {
Expand Down Expand Up @@ -436,13 +437,11 @@ RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext
}

if (_positiveDefinite) {
for (double& val : sum) {
if (val < 0.)
val = 0.;
for(unsigned int j=0; j < nominal.size(); ++j) {
if (sum[j] < 0.)
sum[j] = 0.;
}
}

return sum;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/inc/RooBinWidthFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class RooBinWidthFunction : public RooAbsReal {
bool divideByBinWidth() const { return _divideByBinWidth; }
const RooHistFunc& histFunc() const { return (*_histFunc); }
double evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

private:
RooTemplateProxy<const RooHistFunc> _histFunc;
Expand Down
4 changes: 2 additions & 2 deletions roofit/roofitcore/inc/RooHistFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ class RooHistFunc : public RooAbsReal {


Int_t getBin() const;
std::vector<Int_t> getBins(RooBatchCompute::RunContext& evalData) const;
std::vector<Int_t> getBins(RooBatchCompute::DataMap& dataMap) const;

protected:

Bool_t importWorkspaceHook(RooWorkspace& ws) override ;
Bool_t areIdentical(const RooDataHist& dh1, const RooDataHist& dh2) ;

Double_t evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;
friend class RooAbsCachedReal ;

void ioStreamerPass2() override ;
Expand Down
12 changes: 4 additions & 8 deletions roofit/roofitcore/src/RooBinWidthFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,18 @@ double RooBinWidthFunction::evaluate() const {
/// Compute bin index for all values of the observable(s) in `evalData`, and return their volumes or inverse volumes, depending
/// on the configuration chosen in the constructor.
/// If a bin is not valid, return a volume of 1.
RooSpan<double> RooBinWidthFunction::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const {
void RooBinWidthFunction::computeBatch(cudaStream_t*, double* output, size_t, RooBatchCompute::DataMap& dataMap) const {
const RooDataHist& dataHist = _histFunc->dataHist();
std::vector<Int_t> bins = _histFunc->getBins(evalData);
std::vector<Int_t> bins = _histFunc->getBins(dataMap);
auto volumes = dataHist.binVolumes(0, dataHist.numEntries());

auto results = evalData.makeBatch(this, bins.size());

if (_divideByBinWidth) {
for (std::size_t i=0; i < bins.size(); ++i) {
results[i] = bins[i] >= 0 ? 1./volumes[bins[i]] : 1.;
output[i] = bins[i] >= 0 ? 1./volumes[bins[i]] : 1.;
}
} else {
for (std::size_t i=0; i < bins.size(); ++i) {
results[i] = bins[i] >= 0 ? volumes[bins[i]] : 1.;
output[i] = bins[i] >= 0 ? volumes[bins[i]] : 1.;
}
}

return results;
}
2 changes: 1 addition & 1 deletion roofit/roofitcore/src/RooFitDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ void RooFitDriver::init()
// Some checks and logging of used architectures
{
auto log = [](std::string_view message) {
oocxcoutI(static_cast<RooAbsArg *>(nullptr), FastEvaluations) << message << std::endl;
oocxcoutI(static_cast<RooAbsArg *>(nullptr), Fitting) << message << std::endl;
};

if (_batchMode == RooFit::BatchModeOption::Cuda && !RooBatchCompute::dispatchCUDA) {
Expand Down
21 changes: 6 additions & 15 deletions roofit/roofitcore/src/RooHistFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -202,26 +202,19 @@ Double_t RooHistFunc::evaluate() const
}


////////////////////////////////////////////////////////////////////////////////
/// Compute value of the HistFunc for every entry in `evalData`.
/// \param[in,out] evalData Struct with input data. The computation results will be stored here.
RooSpan<double> RooHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* /*normSet*/) const {
void RooHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap& dataMap) const {
std::vector<RooSpan<const double>> inputValues;
std::size_t batchSize = 0;
for (const auto& obs : _depList) {
auto realObs = dynamic_cast<const RooAbsReal*>(obs);
if (realObs) {
auto inputs = realObs->getValues(evalData, nullptr);
batchSize = std::max(batchSize, inputs.size());
auto inputs = dataMap[realObs];
inputValues.push_back(std::move(inputs));
} else {
inputValues.emplace_back();
}
}

auto results = evalData.makeBatch(this, batchSize);

for (std::size_t i = 0; i < batchSize; ++i) {
for (std::size_t i = 0; i < size; ++i) {
bool skip = false;

for (auto j = 0u; j < _histObsList.size(); ++j) {
Expand All @@ -236,10 +229,8 @@ RooSpan<double> RooHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalData,
}
}

results[i] = skip ? 0. : _dataHist->weightFast(_histObsList, _intOrder, false, _cdfBoundaries);
output[i] = skip ? 0. : _dataHist->weightFast(_histObsList, _intOrder, false, _cdfBoundaries);
}

return results;
}


Expand Down Expand Up @@ -606,12 +597,12 @@ Int_t RooHistFunc::getBin() const {
////////////////////////////////////////////////////////////////////////////////
/// Compute bin numbers corresponding to all coordinates in `evalData`.
/// \return Vector of bin numbers. If a bin is not in the current range of the observables, return -1.
std::vector<Int_t> RooHistFunc::getBins(RooBatchCompute::RunContext& evalData) const {
std::vector<Int_t> RooHistFunc::getBins(RooBatchCompute::DataMap& dataMap) const {
std::vector<RooSpan<const double>> depData;
for (const auto dep : _depList) {
auto real = dynamic_cast<const RooAbsReal*>(dep);
if (real) {
depData.push_back(real->getValues(evalData, nullptr));
depData.push_back(dataMap[real]);
} else {
depData.emplace_back(nullptr, 0);
}
Expand Down

0 comments on commit c45c7b6

Please sign in to comment.