Skip to content

Commit

Permalink
[RF] Support ShapeSys in JSON tool
Browse files Browse the repository at this point in the history
Refactor ParamHistFunc treatment to support both MC stat (BB-lite)
and user-defined ShapeSys

Validated on ATLAS VHbb workspace
  • Loading branch information
gartrog committed Jan 27, 2022
1 parent 2c6d8b2 commit 6455cd7
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 95 deletions.
285 changes: 192 additions & 93 deletions roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class RooHistogramFactory : public RooJSONFactoryWSTool::Importer {
}
}
}

if (p.has_child("overallSystematics")) {
RooArgList nps;
std::vector<double> low;
Expand All @@ -227,6 +228,7 @@ class RooHistogramFactory : public RooJSONFactoryWSTool::Importer {
normElems.add(*v);
ownedArgsStack.push(std::move(v));
}

if (p.has_child("histogramSystematics")) {
RooArgList nps;
RooArgList low;
Expand Down Expand Up @@ -256,6 +258,20 @@ class RooHistogramFactory : public RooJSONFactoryWSTool::Importer {
shapeElems.add(*hf);
ownedArgsStack.push(std::move(hf));
}

if (p.has_child("shapeSystematics")) {
for (const auto &sys : p["shapeSystematics"].children()) {
std::string sysname(RooJSONFactoryWSTool::name(sys));
std::string funcName = prefix + sysname + "_ShapeSys";
RooAbsArg *phf = tool->getScopeObject(funcName);
if (!phf) {
RooJSONFactoryWSTool::error("PHF '" + funcName +
"' should have been created but cannot be found in scope.");
}
shapeElems.add(*phf);
}
}

RooProduct shape(name.c_str(), (name + "_shape").c_str(), shapeElems);
tool->workspace()->import(shape, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
if (normElems.size() > 0) {
Expand All @@ -274,90 +290,128 @@ class RooHistogramFactory : public RooJSONFactoryWSTool::Importer {

class RooRealSumPdfFactory : public RooJSONFactoryWSTool::Importer {
public:
std::unique_ptr<ParamHistFunc> createPHF(const std::string &name, const std::vector<double> &sumW,
const std::vector<double> &sumW2, RooArgList &constraints,
const RooArgSet &observables, double statErrorThreshold,
const std::string &statErrorType) const
std::unique_ptr<ParamHistFunc> createPHF(const std::string &sysname, const std::string &phfname, const std::vector<double> &vals,
RooWorkspace &w, RooArgList &constraints, const RooArgSet &observables,
const std::string &constraintType, RooArgList &gammas, double gamma_min,
double gamma_max) const
{
RooArgList gammas;
RooArgList nps;
RooArgList ownedComponents;
for (size_t i = 0; i < sumW.size(); ++i) {
TString gname = TString::Format("gamma_stat_%s_bin_%d", name.c_str(), (int)i);
double err = sqrt(sumW2[i]) / sumW[i];
auto g = std::make_unique<RooRealVar>(gname.Data(), gname.Data(), 1.);
if (err > 0) {
g->setAttribute("np");
g->setConstant(err < statErrorThreshold);
g->setError(err);
g->setMin(1. - 10 * err);
g->setMax(1. + 10 * err);
nps.add(*g);

if (statErrorType == "Gauss") {
TString tname = TString::Format("nom_gamma_stat_%s_bin_%d", name.c_str(), (int)i);
TString poisname = TString::Format("gamma_stat_%s_bin_%d_constraint", name.c_str(), (int)i);
TString sname = TString::Format("gamma_stat_%s_bin_%d_sigma", name.c_str(), (int)i);
auto tau = std::make_unique<RooRealVar>(tname.Data(), tname.Data(), 1);
tau->setAttribute("glob");
tau->setConstant(true);
tau->setRange(0, 10);
auto sigma = std::make_unique<RooConstVar>(sname.Data(), sname.Data(), err);
auto gaus = std::make_unique<RooGaussian>(poisname.Data(), poisname.Data(), *tau, *g, *sigma);
gaus->addOwnedComponents(std::move(tau), std::move(sigma));
constraints.add(*gaus, true);
ownedComponents.addOwned(std::move(gaus), true);
} else if (statErrorType == "Poisson") {
TString tname = TString::Format("tau_stat_%s_bin_%d", name.c_str(), (int)i);
TString prodname = TString::Format("nExp_stat_%s_bin_%d", name.c_str(), (int)i);
TString poisname = TString::Format("Constraint_stat_%s_bin_%d", name.c_str(), (int)i);
double tauCV = 1. / (err * err);
auto tau = std::make_unique<RooRealVar>(tname.Data(), tname.Data(), tauCV);
tau->setAttribute("glob");
tau->setConstant(true);
tau->setRange(tauCV - 10. / err, tauCV + 10. / err);
RooArgSet elems{*g, *tau};
auto prod = std::make_unique<RooProduct>(prodname.Data(), prodname.Data(), elems);
auto pois = std::make_unique<RooPoisson>(poisname.Data(), poisname.Data(), *tau, *prod);
pois->addOwnedComponents(std::move(tau), std::move(prod));
pois->setNoRounding(true);
constraints.add(*pois, true);
ownedComponents.addOwned(std::move(pois), true);
} else {
RooJSONFactoryWSTool::error("unknown constraint type " + statErrorType);
}
} else {
g->setConstant(true);

std::string funcParams = "gamma_" + sysname;
gammas.add(ParamHistFunc::createParamSet(w, funcParams.c_str(), observables, gamma_min, gamma_max));
auto phf = std::make_unique<ParamHistFunc>(phfname.c_str(), phfname.c_str(), observables, gammas);
for (auto &g : gammas) {
g->setAttribute("np");
}

if (constraintType == "Gauss") {
for (size_t i = 0; i < vals.size(); ++i) {
TString nomname = TString::Format("nom_%s", gammas[i].GetName());
TString poisname = TString::Format("%s_constraint", gammas[i].GetName());
TString sname = TString::Format("%s_sigma", gammas[i].GetName());
auto nom = std::make_unique<RooRealVar>(nomname.Data(), nomname.Data(), 1);
nom->setAttribute("glob");
nom->setConstant(true);
nom->setRange(0, std::max(10., gamma_max));
auto sigma = std::make_unique<RooConstVar>(sname.Data(), sname.Data(), vals[i]);
auto g = static_cast<RooRealVar *>(gammas.at(i));
auto gaus = std::make_unique<RooGaussian>(poisname.Data(), poisname.Data(), *nom, *g, *sigma);
gaus->addOwnedComponents(std::move(nom), std::move(sigma));
constraints.add(*gaus, true);
ownedComponents.addOwned(std::move(gaus), true);
}
} else if (constraintType == "Poisson") {
for (size_t i = 0; i < vals.size(); ++i) {
double tau_float = vals[i];
TString tname = TString::Format("%s_tau", gammas[i].GetName());
TString nomname = TString::Format("nom_%s", gammas[i].GetName());
TString prodname = TString::Format("%s_poisMean", gammas[i].GetName());
TString poisname = TString::Format("%s_constraint", gammas[i].GetName());
auto tau = std::make_unique<RooConstVar>(tname.Data(), tname.Data(), tau_float);
auto nom = std::make_unique<RooRealVar>(nomname.Data(), nomname.Data(), tau_float);
nom->setAttribute("glob");
nom->setConstant(true);
nom->setMin(0);
RooArgSet elems{gammas[i], *tau};
auto prod = std::make_unique<RooProduct>(prodname.Data(), prodname.Data(), elems);
auto pois = std::make_unique<RooPoisson>(poisname.Data(), poisname.Data(), *nom, *prod);
pois->addOwnedComponents(std::move(tau), std::move(nom), std::move(prod));
pois->setNoRounding(true);
constraints.add(*pois, true);
ownedComponents.addOwned(std::move(pois), true);
}
gammas.add(*g, true);
ownedComponents.addOwned(std::move(g), true);
} else {
RooJSONFactoryWSTool::error("unknown constraint type " + constraintType);
}
for (auto &np : nps) {
for (auto client : np->clients()) {
for (auto &g : gammas) {
for (auto client : g->clients()) {
if (client->InheritsFrom(RooAbsPdf::Class()) && !constraints.find(*client)) {
constraints.add(*client);
}
}
}
if (!gammas.empty()) {
auto phf = std::make_unique<ParamHistFunc>(TString::Format("%s_mcstat", name.c_str()), "staterror",
observables, gammas);
phf->recursiveRedirectServers(observables);
// Transfer ownership of gammas and owned constraints to the ParamHistFunc
phf->addOwnedComponents(std::move(ownedComponents));

return phf;
}

std::unique_ptr<ParamHistFunc> createPHFMCStat(const std::string &name, const std::vector<double> &sumW,
const std::vector<double> &sumW2, RooWorkspace &w,
RooArgList &constraints, const RooArgSet &observables, double statErrorThreshold,
const std::string &statErrorType) const
{
if (sumW.size() == 0)
return nullptr;

RooArgList gammas;
std::string phfname = std::string("mc_stat_") + name;
std::string sysname = std::string("stat_") + name;
std::vector<double> vals(sumW.size());
std::vector<double> errs(sumW.size());

for (size_t i = 0; i < sumW.size(); ++i) {
errs[i] = sqrt(sumW2[i]) / sumW[i];
if (statErrorType == "Gauss") {
vals[i] = std::max(errs[i], 0.); // avoid negative sigma. This NP will be set constant anyway later
} else if (statErrorType == "Poisson") {
vals[i] = sumW[i] * sumW[i] / sumW2[i];
}
}

phf->recursiveRedirectServers(observables);
auto phf = createPHF(sysname, phfname, vals, w, constraints, observables, statErrorType, gammas, 0, 10);

// Transfer ownership of gammas and owned constraints to the ParamHistFunc
phf->addOwnedComponents(std::move(ownedComponents));
// set constant NPs which are below the MC stat threshold, and remove them from the np list
for (size_t i = 0; i < sumW.size(); ++i) {
auto g = static_cast<RooRealVar *>(gammas.at(i));
g->setError(errs[i]);
if (errs[i] < statErrorThreshold) {
g->setConstant(true); // all negative errs are set constant
}
}

return phf;
}

return phf;
std::unique_ptr<ParamHistFunc> createPHFShapeSys(const JSONNode &p, const std::string &phfname, RooWorkspace &w,
RooArgList &constraints, const RooArgSet &observables) const
{
std::string sysname(RooJSONFactoryWSTool::name(p));
std::vector<double> vals;
for (const auto &v : p["vals"].children()) {
vals.push_back(v.val_float());
}
return nullptr;
RooArgList gammas;
return createPHF(sysname, phfname, vals, w, constraints, observables, p["constraint"].val(), gammas, 0, 1000);
}

bool importPdf(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
{
std::string name(RooJSONFactoryWSTool::name(p));
RooArgList funcs;
RooArgList coefs;
RooArgList constraints;
if (!p.has_child("samples")) {
RooJSONFactoryWSTool::error("no samples in '" + name + "', skipping.");
}
Expand Down Expand Up @@ -395,6 +449,20 @@ class RooRealSumPdfFactory : public RooJSONFactoryWSTool::Importer {
::collectNames(def["overallSystematics"], sysnames);
if (def.has_child("histogramSystematics"))
::collectNames(def["histogramSystematics"], sysnames);
if (def.has_child("shapeSystematics")) { // ShapeSys are special case. Create PHFs here if needed
std::vector<std::string> shapeSysNames;
::collectNames(def["shapeSystematics"], shapeSysNames);
for (auto &sysname : shapeSysNames) {
std::string phfname = name + "_" + sysname + "_ShapeSys";
auto phf = tool->getScopeObject(phfname);
if (!phf) {
auto newphf = createPHFShapeSys(def["shapeSystematics"][sysname], phfname, *(tool->workspace()),
constraints, observables);
tool->workspace()->import(*newphf, RooFit::RecycleConflictNodes(), RooFit::Silence(true));
tool->setScopeObject(phfname, tool->workspace()->function(phfname.c_str()));
}
}
}
} catch (const char *s) {
RooJSONFactoryWSTool::error("function '" + name + "' unable to collect observables from function " +
fname + ". " + s);
Expand All @@ -412,10 +480,8 @@ class RooRealSumPdfFactory : public RooJSONFactoryWSTool::Importer {
coefnames.push_back(fprefix + fname + "_norm");
}

RooArgList constraints;
auto phf = createPHF(name, sumW, sumW2, constraints, observables, statErrorThreshold, statErrorType);
phf->Print();
constraints.Print();
auto phf = createPHFMCStat(name, sumW, sumW2, *(tool->workspace()), constraints, observables,
statErrorThreshold, statErrorType);
if (phf) {
tool->workspace()->import(*phf, RooFit::RecycleConflictNodes(), RooFit::Silence(true));
tool->setScopeObject("mcstat", tool->workspace()->function(phf->GetName()));
Expand Down Expand Up @@ -444,7 +510,7 @@ class RooRealSumPdfFactory : public RooJSONFactoryWSTool::Importer {
tool->workspace()->import(sum, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
RooArgList lhelems;
lhelems.add(sum);
RooProdPdf prod(name.c_str(), name.c_str(), constraints, RooFit::Conditional(lhelems, observables));
RooProdPdf prod(name.c_str(), name.c_str(), RooArgSet(constraints), RooFit::Conditional(lhelems, observables));
tool->workspace()->import(prod, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
}

Expand Down Expand Up @@ -646,7 +712,7 @@ class HistFactoryStreamer : public RooJSONFactoryWSTool::Exporter {
elems.add(*coef);
}
std::unique_ptr<TH1> hist;
ParamHistFunc *phf = nullptr;
std::vector<ParamHistFunc *> phfs;
PiecewiseInterpolation *pip = nullptr;
RooStats::HistFactory::FlexibleInterpVar *fip = nullptr;
for (const auto &e : elems) {
Expand All @@ -673,7 +739,7 @@ class HistFactoryStreamer : public RooJSONFactoryWSTool::Exporter {
} else if (e->InheritsFrom(PiecewiseInterpolation::Class())) {
pip = static_cast<PiecewiseInterpolation *>(e);
} else if (e->InheritsFrom(ParamHistFunc::Class())) {
phf = (ParamHistFunc *)e;
phfs.push_back((ParamHistFunc *)e);
}
}
if (pip) {
Expand Down Expand Up @@ -719,31 +785,64 @@ class HistFactoryStreamer : public RooJSONFactoryWSTool::Exporter {
sys["high"] << fip->high()[i];
}
}
if (phf) {
s["statError"] << 1;
int idx = 0;
for (const auto &g : phf->paramList()) {
++idx;
RooPoisson *constraint_p = findClient<RooPoisson>(g);
RooGaussian *constraint_g = findClient<RooGaussian>(g);
if (tot_yield.find(idx) == tot_yield.end()) {
tot_yield[idx] = 0;
tot_yield2[idx] = 0;
bool has_mc_stat = false;
for (auto phf : phfs) {
if (TString(phf->GetName()).BeginsWith("mc_stat_")) { // MC stat uncertainty
has_mc_stat = true;
s["statError"] << 1;
int idx = 0;
for (const auto &g : phf->paramList()) {
++idx;
RooPoisson *constraint_p = findClient<RooPoisson>(g);
RooGaussian *constraint_g = findClient<RooGaussian>(g);
if (tot_yield.find(idx) == tot_yield.end()) {
tot_yield[idx] = 0;
tot_yield2[idx] = 0;
}
tot_yield[idx] += hist->GetBinContent(idx);
tot_yield2[idx] += (hist->GetBinContent(idx) * hist->GetBinContent(idx));
if (constraint_p) {
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
rel_errors[idx] = erel;
has_poisson_constraints = true;
} else if (constraint_g) {
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
rel_errors[idx] = erel;
has_gauss_constraints = true;
}
}
tot_yield[idx] += hist->GetBinContent(idx);
tot_yield2[idx] += (hist->GetBinContent(idx) * hist->GetBinContent(idx));
if (constraint_p) {
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
rel_errors[idx] = erel;
has_poisson_constraints = true;
} else if (constraint_g) {
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
rel_errors[idx] = erel;
has_gauss_constraints = true;
bb_histograms[samplename] = std::move(hist);
} else { // other ShapeSys
auto &shapesysts = s["shapeSystematics"];
shapesysts.set_map();
// Getting the name of the syst is tricky.
TString sysName(phf->GetName());
sysName.Remove(sysName.Index("_ShapeSys"));
sysName.Remove(0, chname.size() + 1);
auto &sys = shapesysts[sysName.Data()];
sys.set_map();
auto &cstrts = sys["vals"];
cstrts.set_seq();
bool is_poisson = false;
for (const auto &g : phf->paramList()) {
RooPoisson *constraint_p = findClient<RooPoisson>(g);
RooGaussian *constraint_g = findClient<RooGaussian>(g);
if (constraint_p) {
is_poisson = true;
cstrts.append_child() << constraint_p->getX().getVal();
} else if (constraint_g) {
is_poisson = false;
cstrts.append_child() << constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
}
}
if (is_poisson) {
sys["constraint"] << "Poisson";
} else {
sys["constraint"] << "Gauss";
}
}
bb_histograms[samplename] = std::move(hist);
} else {
}
if (!has_mc_stat) {
nonbb_histograms[samplename] = std::move(hist);
s["statError"] << 0;
}
Expand Down
4 changes: 2 additions & 2 deletions roofit/hs3/src/RooJSONFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ std::vector<std::vector<int>> RooJSONFactoryWSTool::generateBinIndices(const Roo
std::vector<std::vector<int>> combinations;
std::vector<int> vars_numbins;
vars_numbins.reserve(vars.size());
for (auto &absv : vars) {
vars_numbins.push_back(((RooRealVar *)absv)->numBins());
for (const auto &absv : static_range_cast<RooRealVar *>(vars)) {
vars_numbins.push_back(absv->numBins());
}
std::vector<int> curr_comb(vars.size());
::genIndicesHelper(combinations, curr_comb, vars_numbins, 0);
Expand Down

0 comments on commit 6455cd7

Please sign in to comment.