Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roofit/batchcompute/src/RooBatchCompute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public:
}
virtual void cudaStreamWaitEvent(cudaStream_t *stream, cudaEvent_t *event)
{
ERRCHECK(::cudaStreamWaitEvent(*stream, *event));
ERRCHECK(::cudaStreamWaitEvent(*stream, *event, 0));
}
virtual float cudaEventElapsedTime(cudaEvent_t *begin, cudaEvent_t *end)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ParamHistFunc : public RooAbsReal {
Int_t addParamSet( const RooArgList& params );
static Int_t GetNumBins( const RooArgSet& vars );
double evaluate() const override;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const override;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const override;

private:
static NumBins getNumBinsPerDim(RooArgSet const& vars);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class PiecewiseInterpolation : public RooAbsReal {
std::vector<int> _interpCode;

Double_t evaluate() const;
RooSpan<double> evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const;
void computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap&) const;

ClassDef(PiecewiseInterpolation,4) // Sum of RooAbsReal objects
};
Expand Down
12 changes: 3 additions & 9 deletions roofit/histfactory/src/ParamHistFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -591,23 +591,19 @@ Double_t ParamHistFunc::evaluate() const
/// the associated parameters.
/// \param[in/out] evalData Input/output data for evaluating the ParamHistFunc.
/// \param[in] normSet Normalisation set passed on to objects that are serving values to us.
RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
void ParamHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooBatchCompute::DataMap& dataMap) const {
std::vector<double> oldValues;
std::vector<RooSpan<const double>> data;
std::size_t batchSize = 0;

// Retrieve data for all variables
for (auto arg : _dataVars) {
const auto* var = static_cast<RooRealVar*>(arg);
oldValues.push_back(var->getVal());
data.push_back(var->getValues(evalData, normSet));
batchSize = std::max(batchSize, data.back().size());
data.push_back(dataMap[var]);
}

// Run computation for each entry in the dataset
RooSpan<double> output = evalData.makeBatch(this, batchSize);

for (std::size_t i = 0; i < batchSize; ++i) {
for (std::size_t i = 0; i < size; ++i) {
for (unsigned int j = 0; j < _dataVars.size(); ++j) {
assert(i < data[j].size());
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
Expand All @@ -624,8 +620,6 @@ RooSpan<double> ParamHistFunc::evaluateSpan(RooBatchCompute::RunContext& evalDat
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
var.setCachedValue(oldValues[j], /*notifyClients=*/false);
}

return output;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
21 changes: 10 additions & 11 deletions roofit/histfactory/src/PiecewiseInterpolation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,16 @@ Double_t PiecewiseInterpolation::evaluate() const
/// Interpolate between input distributions for all values of the observable in `evalData`.
/// \param[in/out] evalData Struct holding spans pointing to input data. The results of this function will be stored here.
/// \param[in] normSet Arguments to normalise over.
RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext& evalData, const RooArgSet* normSet) const {
auto nominal = _nominal->getValues(evalData, normSet);
auto sum = evalData.makeBatch(this, nominal.size());
std::copy(nominal.begin(), nominal.end(), sum.begin());
void PiecewiseInterpolation::computeBatch(cudaStream_t*, double* sum, size_t /*size*/, RooBatchCompute::DataMap& dataMap) const {
auto nominal = dataMap[&*_nominal];
for(unsigned int j=0; j < nominal.size(); ++j) {
sum[j] = nominal[j];
}

for (unsigned int i=0; i < _paramSet.size(); ++i) {
const double param = static_cast<RooAbsReal*>(_paramSet.at(i))->getVal();
auto low = static_cast<RooAbsReal*>(_lowSet.at(i) )->getValues(evalData, normSet);
auto high = static_cast<RooAbsReal*>(_highSet.at(i))->getValues(evalData, normSet);
auto low = dataMap[_lowSet.at(i)];
auto high = dataMap[_highSet.at(i)];
const int icode = _interpCode[i];

switch(icode) {
Expand Down Expand Up @@ -436,13 +437,11 @@ RooSpan<double> PiecewiseInterpolation::evaluateSpan(RooBatchCompute::RunContext
}

if (_positiveDefinite) {
for (double& val : sum) {
if (val < 0.)
val = 0.;
for(unsigned int j=0; j < nominal.size(); ++j) {
if (sum[j] < 0.)
sum[j] = 0.;
}
}

return sum;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion roofit/hs3/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitHS3
${RYMLSources}
DICTIONARY_OPTIONS
"-writeEmptyRootPCM"
DEPENDENCIES
LIBRARIES
RooFit
RooFitCore
RooStats
Expand Down
21 changes: 5 additions & 16 deletions roofit/hs3/inc/RooFitHS3/HistFactoryJSONTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,25 @@
#include <iostream>
#include <string>

namespace RooFit {
namespace Experimental {
class JSONNode;
}
} // namespace RooFit

namespace RooStats {
namespace HistFactory {

class Channel;
class Measurement;
class Sample;

class JSONTool {
protected:
RooStats::HistFactory::Measurement *_measurement;

void Export(const RooStats::HistFactory::Channel &c, RooFit::Experimental::JSONNode &t) const;
void Export(const RooStats::HistFactory::Sample &s, RooFit::Experimental::JSONNode &t) const;

public:
JSONTool(RooStats::HistFactory::Measurement *);
JSONTool(RooStats::HistFactory::Measurement &m) : _measurement(m) {}

void PrintJSON(std::ostream &os = std::cout);
void PrintJSON(std::string const &filename);
void PrintYAML(std::ostream &os = std::cout);
void PrintYAML(std::string const &filename);
void Export(RooFit::Experimental::JSONNode &t) const;

private:
RooStats::HistFactory::Measurement &_measurement;
};

} // namespace HistFactory
} // namespace RooStats

#endif
21 changes: 12 additions & 9 deletions roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
#ifndef RooFitHS3_RooJSONFactoryWSTool_h
#define RooFitHS3_RooJSONFactoryWSTool_h

#include <RooArgSet.h>
#include <RooGlobalFunc.h>

#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

class RooArgList;
class RooAbsData;
class RooArgSet;
class RooAbsArg;
class RooAbsReal;
class RooAbsPdf;
Expand All @@ -30,6 +32,7 @@ class RooRealVar;
class RooWorkspace;

class TH1;
class TClass;

namespace RooFit {
namespace Experimental {
Expand Down Expand Up @@ -67,13 +70,13 @@ class RooJSONFactoryWSTool {
std::map<std::string, std::string> proxies;
};
struct ImportExpression {
TClass const* tclass = nullptr;
TClass const *tclass = nullptr;
std::vector<std::string> arguments;
};

typedef std::map<const std::string, std::vector<std::unique_ptr<const Importer>>> ImportMap;
typedef std::map<TClass const*, std::vector<std::unique_ptr<const Exporter>>> ExportMap;
typedef std::map<TClass const*, ExportKeys> ExportKeysMap;
typedef std::map<TClass const *, std::vector<std::unique_ptr<const Exporter>>> ExportMap;
typedef std::map<TClass const *, ExportKeys> ExportKeysMap;
typedef std::map<const std::string, ImportExpression> ImportExpressionMap;

// The following maps to hold the importers and exporters for runtime lookup
Expand All @@ -96,11 +99,11 @@ class RooJSONFactoryWSTool {
Var(const RooFit::Experimental::JSONNode &val);
};

std::ostream &log(RooFit::MsgLevel level) const;
std::ostream &log(int level) const;

protected:
struct Scope {
RooArgSet observables;
std::vector<RooAbsArg *> observables;
std::map<std::string, RooAbsArg *> objects;
};
mutable Scope _scope;
Expand Down Expand Up @@ -227,7 +230,7 @@ class RooJSONFactoryWSTool {
readBinnedData(const RooFit::Experimental::JSONNode &n, const std::string &namecomp, RooArgList observables);
static std::map<std::string, RooJSONFactoryWSTool::Var>
readObservables(const RooFit::Experimental::JSONNode &n, const std::string &obsnamecomp);
RooArgSet getObservables(const RooFit::Experimental::JSONNode &n, const std::string &obsnamecomp);
void getObservables(const RooFit::Experimental::JSONNode &n, const std::string &obsnamecomp, RooArgSet &out);
void setScopeObservables(const RooArgList &args);
RooAbsArg *getScopeObject(const std::string &name);
void setScopeObject(const std::string &key, RooAbsArg *obj);
Expand Down
Loading