Skip to content

Commit

Permalink
Multi target support
Browse files Browse the repository at this point in the history
  • Loading branch information
fradav committed Jan 25, 2021
1 parent e8d3a6c commit 9b4adff
Show file tree
Hide file tree
Showing 12 changed files with 486 additions and 290 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ endif()

if(TEST)
add_subdirectory(test)
add_subdirectory(testpy)
# add_subdirectory(testpy)
else()
message("Skipping tests")
endif()
Expand Down
18 changes: 16 additions & 2 deletions abcranger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,27 @@ int main(int argc, char* argv[]) {
chosenscen = static_cast<double>(opts["chosenscen"].as<size_t>());

auto myread = readreftable_scen(headerfile, reftablefile, chosenscen, nref);
const auto statobs = readStatObs(statobsfile);
auto origobs = readStatObs(statobsfile);
size_t nstat = myread.stats_names.size();
size_t num_samples = origobs.size() / nstat;
if (((origobs.size() % nstat) != 0) || (num_samples < 1)) {
std::cout << "wrong number of summary statistics in statobs file." << std::endl;
exit(1);
}
MatrixXd statobs = Map<MatrixXd>(origobs.data(),nstat,num_samples).transpose();
auto res = EstimParam_fun(myread, statobs, opts);
} else {
std::cout << "> Model Choice <" << std::endl;

auto myread = readreftable(headerfile, reftablefile, nref, false, opts.count("g") > 0 ? opts["g"].as<std::string>() : "");
const auto statobs = readStatObs(statobsfile);
auto origobs = readStatObs(statobsfile);
size_t nstat = myread.stats_names.size();
size_t num_samples = origobs.size() / nstat;
if (((origobs.size() % nstat) != 0) || (num_samples < 1)) {
std::cout << "wrong number of summary statistics in statobs file." << std::endl;
exit(1);
}
MatrixXd statobs = Map<MatrixXd>(origobs.data(),nstat,num_samples).transpose();
auto res = ModelChoice_fun(myread,statobs,opts);

}
Expand Down
588 changes: 340 additions & 248 deletions src/EstimParam.cpp

Large diffs are not rendered by default.

38 changes: 23 additions & 15 deletions src/EstimParam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,33 @@
#include <map>
#include "Reftable.hpp"
#include "cxxopts.hpp"
#include <Eigen/Dense>

struct EstimParamResults {
struct EstimParamResults
{
std::vector<double> plsvar;
std::vector<std::pair<std::string,double>> plsweights;
std::vector<std::pair<std::string,double>> variable_importance;
std::vector<std::pair<std::string, double>> plsweights;
std::vector<std::pair<std::string, double>> variable_importance;
std::vector<double> ntree_oob_error;
std::vector<std::pair<double,double>> values_weights;
std::map<size_t,size_t> oob_map;
std::vector<std::pair<double, double>> values_weights;
std::map<size_t, size_t> oob_map;
Eigen::MatrixXd oob_weights;
std::map<std::string,double> point_estimates;
std::map<std::string, // Global/Local
std::map<std::string, // Mean/Median/CI
std::map<std::string,double>>> errors;
std::vector<std::map<std::string, double>> point_estimates;
std::map<std::string, // Global/Local
std::vector<std::map<std::string, // Mean/Median/CI
std::map<std::string, double>>>>
errors;
};

template<class MatrixType>
template <class MatrixType>
EstimParamResults EstimParam_fun(Reftable<MatrixType> &reftable,
std::vector<double> statobs,
const cxxopts::ParseResult opts,
bool quiet = false,
bool weights = false);
MatrixXd statobs,
const cxxopts::ParseResult opts,
bool quiet = false,
bool weights = false);

/// Overload of EstimParam_fun that receives the observed statistics as a flat
/// std::vector<double> instead of a matrix.
///
/// NOTE(review): the definition is not visible from this header; presumably the
/// vector holds the statistics of a single target — confirm in EstimParam.cpp.
///
/// @param reftable reference table
/// @param statobs  observed summary statistics
/// @param opts     parsed command-line style options
/// @param quiet    defaults to false
/// @param weights  defaults to false
/// @return an EstimParamResults value
template <class MatrixType>
EstimParamResults EstimParam_fun(Reftable<MatrixType> &reftable,
std::vector<double> statobs,
const cxxopts::ParseResult opts,
bool quiet = false,
bool weights = false);
3 changes: 2 additions & 1 deletion src/ForestOnlineRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ void ForestOnlineRegression::predictInternal(size_t tree_idx)
predictions[1][sample_idx][tree_idx] = static_cast<double>(getTreePredictionTerminalNodeID(tree_idx, sample_idx));
else {
auto value = getTreePrediction(tree_idx, sample_idx);
if (std::isnan(value)) throw std::runtime_error("NaN value");
// if (std::isnan(value)) throw std::runtime_error("NaN value");
if (std::isnan(value)) next;
mutex_samples[sample_idx].lock();
predictions[1][0][sample_idx] += value;
mutex_samples[sample_idx].unlock();
Expand Down
64 changes: 47 additions & 17 deletions src/ModelChoice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace ranges;

template<class MatrixType>
ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
std::vector<double> obs,
MatrixXd statobs,
const cxxopts::ParseResult opts,
bool quiet)
{
Expand All @@ -43,12 +43,11 @@ ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
std::vector<double> samplefract{std::min(1e5,static_cast<double>(myread.nrec))/static_cast<double>(myread.nrec)};
auto nstat = myread.stats_names.size();
size_t K = myread.nrecscen.size();
MatrixXd statobs(1, nstat);
MatrixXd emptyrow(1,0);
size_t num_samples = statobs.rows();

size_t n = myread.nrec;

statobs = Map<MatrixXd>(obs.data(), 1, nstat);

MatrixXd data_extended(n,0);

Expand Down Expand Up @@ -92,7 +91,7 @@ ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
for(auto i = 0; i < varwithouty.size(); i++) varwithouty[i] = myread.stats_names[i];


auto datastatobs = unique_cast<DataDense<MatrixXd>, Data>(std::make_unique<DataDense<MatrixXd>>(statobs, emptyrow, varwithouty, 1, varwithouty.size()));
auto datastatobs = unique_cast<DataDense<MatrixXd>, Data>(std::make_unique<DataDense<MatrixXd>>(statobs, emptyrow, varwithouty, num_samples, varwithouty.size()));
auto datastats = unique_cast<DataDense<MatrixType>, Data>(std::make_unique<DataDense<MatrixType>>(myread.stats, data_extended, myread.stats_names, myread.nrec, myread.stats_names.size()));
ForestOnlineClassification forestclass;
forestclass.init("Y", // dependant variable
Expand Down Expand Up @@ -143,12 +142,13 @@ ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
res.ntree_oob_error = preds[2][0];
if (!quiet) forestclass.writeOOBErrorFile();

vector<size_t> votes(K);
for(auto& tree_pred : preds[1][0]) votes[static_cast<size_t>(tree_pred-1)]++;
res.votes = votes;

size_t predicted_model = std::distance(votes.begin(),std::max_element(votes.begin(),votes.end()));
res.predicted_model = predicted_model;
size_t nobs = statobs.rows();
res.votes = std::vector<std::vector<size_t>>(num_samples,std::vector<size_t>(K));
res.predicted_model = std::vector<size_t>(num_samples);
for(auto i = 0; i < num_samples; i++) {
for(auto& tree_pred : preds[1][i]) res.votes[i][static_cast<size_t>(tree_pred-1)]++;
res.predicted_model[i] = std::distance(res.votes[i].begin(),std::max_element(res.votes[i].begin(),res.votes[i].end()));
}

size_t ycol = data_extended.cols() - 1;

Expand Down Expand Up @@ -219,20 +219,26 @@ ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
myread.stats_names.resize(nstat);

auto predserr = forestreg.getPredictions();
res.post_proba = predserr[1][0][0];
res.post_proba = std::vector<double>(num_samples);
for(auto i = 0; i < num_samples; i++) res.post_proba[i] = predserr[1][0][i];
const std::string& predict_filename = outfile + ".predictions";

std::ostringstream os;
for(auto i = 0; i < votes.size(); i++) {
if (num_samples > 1) os << fmt::format("{:>14}", "Target n°");
for(auto i = 0; i < K; i++) {
os << fmt::format("{:>14}",fmt::format("votes model{0}",i+1));
}
os << fmt::format(" selected model");
os << fmt::format(" post proba\n");
for(auto i = 0; i < votes.size(); i++) {
os << fmt::format("{:>14}",votes[i]);
os << fmt::format(" post proba\n");
for (auto j = 0; j < num_samples; j++) {
if (num_samples > 1)
os << fmt::format("{:>14}", j + 1);
for(auto i = 0; i < K; i++) {
os << fmt::format(" {:>13}",res.votes[j][i]);
}
os << fmt::format("{:>15}", res.predicted_model[j] + 1);
os << fmt::format("{:12.3f}\n",res.post_proba[j]);
}
os << fmt::format("{:>15}", predicted_model + 1);
os << fmt::format("{:11.3f}\n",res.post_proba);
if (!quiet) std::cout << os.str();
std::cout.flush();

Expand Down Expand Up @@ -277,6 +283,30 @@ ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
return res;
}

/// Convenience overload of ModelChoice_fun for a single observed target that is
/// supplied as a flat vector.
///
/// The first `myread.stats_names.size()` values of @p origobs are copied into a
/// 1 x nstat matrix, which is then passed to the MatrixXd overload.
///
/// @param myread  reference table; the size of its stats_names fixes nstat
/// @param origobs observed summary statistics of one target
///                (assumed to hold at least nstat values; no size check is done here)
/// @param opts    parsed options, forwarded unchanged
/// @param quiet   forwarded unchanged
/// @return the result of the MatrixXd overload
template<class MatrixType>
ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &myread,
                                   std::vector<double> origobs,
                                   const cxxopts::ParseResult opts,
                                   bool quiet)
{
    const auto n_statistics = myread.stats_names.size();
    // Copy the flat buffer into an owning single-row matrix.
    MatrixXd single_row = Map<MatrixXd>(origobs.data(), 1, n_statistics);
    return ModelChoice_fun(myread, single_row, opts, quiet);
}

// Explicit instantiation of the MatrixXd-observation overload for reference
// tables parameterised on a plain MatrixXd.
template
ModelChoiceResults ModelChoice_fun(Reftable<MatrixXd> &myread,
MatrixXd obs,
const cxxopts::ParseResult opts,
bool quiet);

// Explicit instantiation of the same overload for reference tables
// parameterised on an Eigen::Ref<MatrixXd> with dynamic inner and outer strides.
template
ModelChoiceResults ModelChoice_fun(Reftable<Eigen::Ref<MatrixXd, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>> &myread,
MatrixXd obs,
const cxxopts::ParseResult opts,
bool quiet);

template
ModelChoiceResults ModelChoice_fun(Reftable<MatrixXd> &myread,
std::vector<double> obs,
Expand Down
12 changes: 9 additions & 3 deletions src/ModelChoice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ struct ModelChoiceResults
std::vector<std::vector<size_t>> confusion_matrix;
std::vector<std::pair<std::string, double>> variable_importance;
std::vector<double> ntree_oob_error;
size_t predicted_model;
std::vector<size_t> votes;
double post_proba;
std::vector<size_t> predicted_model;
std::vector<std::vector<size_t>> votes;
std::vector<double> post_proba;
};

/// Model choice for one or more observed targets.
///
/// @param reftable reference table
/// @param statobs  observed summary statistics; the definition takes
///                 statobs.rows() as the number of targets, so each row is
///                 one target
/// @param opts     parsed options
/// @param quiet    defaults to false
/// @return a ModelChoiceResults whose predicted_model, votes and post_proba
///         are vectors (one entry per target)
template<class MatrixType>
ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &reftable,
MatrixXd statobs,
const cxxopts::ParseResult opts,
bool quiet = false);

template<class MatrixType>
ModelChoiceResults ModelChoice_fun(Reftable<MatrixType> &reftable,
std::vector<double> statobs,
Expand Down
10 changes: 10 additions & 0 deletions src/pls-eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ using namespace Eigen;
using namespace std;
using namespace ranges;

/// Return the indices of the columns of @p xr that are not numerically constant.
///
/// A column is kept when its sample standard deviation (denominator
/// `rows - 1`) is at least 1.0e-8. Indices are returned in increasing order.
///
/// NOTE(review): with a single row the denominator is zero; presumably every
/// deviation is then NaN and no column is reported — confirm this is intended.
///
/// @tparam Derived concrete Eigen type of the input expression
/// @param xr data matrix whose columns are tested
/// @return indices of the columns whose standard deviation is >= 1.0e-8
template<class Derived>
std::vector<size_t> filterConstantVars(const MatrixBase<Derived>& xr) {
    // Materialise both reductions with eval(). Binding a lazy Eigen expression
    // to `auto` keeps it unevaluated, so every stdr(i) below would redo the work.
    const auto meanr = xr.colwise().mean().eval();
    const auto stdr = ((xr.rowwise() - meanr).array().square().colwise().sum() / (xr.rows() - 1)).sqrt().eval();

    // Convert once so the loop compares unsigned with unsigned.
    const size_t ncols = static_cast<size_t>(xr.cols());
    std::vector<size_t> validvars;
    validvars.reserve(ncols);
    for (size_t i = 0; i < ncols; i++) {
        if (stdr(i) >= 1.0e-8) validvars.push_back(i);
    }
    return validvars;
}

template<class Derived, class OtherDerived>
VectorXd pls(const MatrixBase<Derived>& x,
Expand Down
21 changes: 19 additions & 2 deletions src/pyabcranger/pyabcranger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,29 @@ ModelChoiceResults ModelChoice_fun_py(Reftable<py::EigenDRef<MatrixXd>> &reftabl
return ModelChoice_fun(reftable,statobs,parseopt(options),quiet);
}

/// Python-facing wrapper around ModelChoice_fun for a matrix of observed statistics.
///
/// Converts the option string with parseopt() and forwards every argument.
///
/// @param reftable reference table, passed through
/// @param statobs  matrix of observed summary statistics, passed through
/// @param options  option string handed to parseopt()
/// @param quiet    passed through; defaults to false
/// @return the ModelChoiceResults returned by ModelChoice_fun
ModelChoiceResults ModelChoice_multi_fun_py(Reftable<py::EigenDRef<MatrixXd>> &reftable,
                                            MatrixXd& statobs,
                                            std::string options,
                                            bool quiet = false) {
    return ModelChoice_fun(
        reftable,
        statobs,
        parseopt(options),
        quiet);
}

/// Python-facing wrapper around EstimParam_fun for a matrix of observed statistics.
///
/// Converts the option string with parseopt() and forwards every argument.
///
/// @param reftable reference table, passed through
/// @param statobs  matrix of observed summary statistics, passed through
/// @param options  option string handed to parseopt()
/// @param quiet    passed through; defaults to false
/// @param weights  passed through; defaults to false
/// @return the EstimParamResults returned by EstimParam_fun
EstimParamResults EstimParam_multi_fun_py(Reftable<py::EigenDRef<MatrixXd>> &reftable,
                                          MatrixXd& statobs,
                                          std::string options,
                                          bool quiet = false,
                                          bool weights = false) {
    return EstimParam_fun(
        reftable,
        statobs,
        parseopt(options),
        quiet,
        weights);
}

/// Python-facing wrapper around EstimParam_fun for observed statistics given
/// as a flat vector.
///
/// Converts the option string with parseopt() and forwards every argument.
///
/// @param reftable reference table, passed through
/// @param statobs  observed summary statistics, passed through
/// @param options  option string handed to parseopt()
/// @param quiet    passed through; defaults to false
/// @param weights  passed through; defaults to false
/// @return the EstimParamResults returned by EstimParam_fun
EstimParamResults EstimParam_fun_py(Reftable<py::EigenDRef<MatrixXd>> &reftable,
                                    std::vector<double> statobs,
                                    std::string options,
                                    bool quiet = false,
                                    bool weights = false) {
    return EstimParam_fun(
        reftable,
        statobs,
        parseopt(options),
        quiet,
        weights);
}
}

using namespace Eigen;

PYBIND11_MODULE(pyabcranger, m) {
Expand Down Expand Up @@ -101,7 +116,9 @@ PYBIND11_MODULE(pyabcranger, m) {
.def_readwrite("errors",&EstimParamResults::errors);

m.def("modelchoice", &ModelChoice_fun_py, py::call_guard<py::scoped_ostream_redirect,py::gil_scoped_release>());
m.def("modelchoice_multi", &ModelChoice_multi_fun_py, py::call_guard<py::scoped_ostream_redirect,py::gil_scoped_release>());
m.def("estimparam", &EstimParam_fun_py, py::call_guard<py::scoped_ostream_redirect,py::gil_scoped_release>());
m.def("estimparam_multi", &EstimParam_multi_fun_py, py::call_guard<py::scoped_ostream_redirect,py::gil_scoped_release>());
m.def("forestQuantiles_b", [](const std::vector<double>& obs,
const std::vector<std::vector<double>>& weights,
const std::vector<double>& asked){
Expand Down
4 changes: 4 additions & 0 deletions test/data/statobsRF2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
ML1p_1 ML1p_2 ML1p_3 ML1p_4 ML2p_1.2 ML2p_1.3 ML2p_1.4 ML2p_2.3 ML2p_2.4 ML2p_3.4 ML3p_1.2.3 ML3p_1.2.4 ML3p_1.3.4 ML3p_2.3.4 HWm_1 HWm_2 HWm_3 HWm_4 HWv_1 HWv_2 HWv_3 HWv_4 HBm_1.2 HBm_1.3 HBm_1.4 HBm_2.3 HBm_2.4 HBm_3.4 HBv_1.2 HBv_1.3 HBv_1.4 HBv_2.3 HBv_2.4 HBv_3.4 FST1m_1 FST1m_2 FST1m_3 FST1m_4 FST1v_1 FST1v_2 FST1v_3 FST1v_4 FST2m_1.2 FST2m_1.3 FST2m_1.4 FST2m_2.3 FST2m_2.4 FST2m_3.4 FST2v_1.2 FST2v_1.3 FST2v_1.4 FST2v_2.3 FST2v_2.4 FST2v_3.4 NEIm_1.2 NEIm_1.3 NEIm_1.4 NEIm_2.3 NEIm_2.4 NEIm_3.4 NEIv_1.2 NEIv_1.3 NEIv_1.4 NEIv_2.3 NEIv_2.4 NEIv_3.4 AMLm_1.2.3 AMLm_2.1.3 AMLm_3.1.2 AMLm_1.2.4 AMLm_2.1.4 AMLm_4.1.2 AMLm_1.3.4 AMLm_3.1.4 AMLm_4.1.3 AMLm_2.3.4 AMLm_3.2.4 AMLm_4.2.3 AMLv_1.2.3 AMLv_2.1.3 AMLv_3.1.2 AMLv_1.2.4 AMLv_2.1.4 AMLv_4.1.2 AMLv_1.3.4 AMLv_3.1.4 AMLv_4.1.3 AMLv_2.3.4 AMLv_3.2.4 AMLv_4.2.3 FST3m_1.2.3 FST3m_1.2.4 FST3m_1.3.4 FST3m_2.3.4 FST3v_1.2.3 FST3v_1.2.4 FST3v_1.3.4 FST3v_2.3.4 FST4m_1.2.3.4 FST4v_1.2.3.4 F3m_1.2.3 F3m_2.1.3 F3m_3.1.2 F3m_1.2.4 F3m_2.1.4 F3m_4.1.2 F3m_1.3.4 F3m_3.1.4 F3m_4.1.3 F3m_2.3.4 F3m_3.2.4 F3m_4.2.3 F3v_1.2.3 F3v_2.1.3 F3v_3.1.2 F3v_1.2.4 F3v_2.1.4 F3v_4.1.2 F3v_1.3.4 F3v_3.1.4 F3v_4.1.3 F3v_2.3.4 F3v_3.2.4 F3v_4.2.3 F4m_1.2.3.4 F4m_1.3.2.4 F4m_1.4.2.3 F4v_1.2.3.4 F4v_1.3.2.4 F4v_1.4.2.3

0.16040000 0.20900000 0.17140000 0.13560000 0.04760000 0.04040000 0.04800000 0.08960000 0.05320000 0.06220000 0.01180000 0.00980000 0.01740000 0.02820000 0.24573415 0.23733326 0.24383954 0.24676333 0.03115738 0.03385790 0.03262390 0.03103325 0.26197939 0.26098241 0.25774253 0.25476021 0.25756725 0.25379183 0.02926918 0.02846975 0.02784482 0.03051622 0.02969639 0.02902693 0.04681769 0.07940407 0.05416672 0.04282559 0.46879378 0.50942572 0.49085903 0.46692610 0.07836146 0.06225006 0.04458234 0.05549477 0.06025723 0.03358946 0.00945500 0.00778649 0.00575185 0.00717580 0.00750894 0.00455083 0.03503132 0.02976547 0.02398338 0.02749575 0.02921387 0.02048164 0.00441123 0.00338649 0.00230662 0.00282876 0.00331057 0.00172847 0.46100068 0.43007067 0.46660871 0.41585252 0.45461734 0.53559965 0.44427018 0.41160874 0.46579589 0.52307069 0.45610501 0.42763848 0.20082501 0.19845348 0.18797401 0.19051696 0.20904682 0.18271819 0.20296278 0.19137962 0.18182854 0.20989997 0.19033713 0.18543345 0.06555334 0.06133700 0.04699794 0.04994778 0.00546414 0.00500669 0.00396932 0.00426342 0.05605008 0.00364899 0.01123372 0.00921197 0.00496184 0.00821026 0.01223543 0.00328353 0.00959948 0.00659609 0.00189431 0.01060119 0.00357263 0.00491777 0.00118795 0.00098739 0.00071405 0.00090522 0.00111591 0.00065685 0.00086428 0.00061343 0.00046779 0.00091743 0.00050876 0.00064362 -0.00302346 -0.00163424 0.00138922 0.00057110 0.00065931 0.00050151
0.16040000 0.20900000 0.16140000 0.13560000 0.01760000 0.04040000 0.04800000 0.08960000 0.05320000 0.06220000 0.01180000 0.00980000 0.01740000 0.02820000 0.24573415 0.23733326 0.24383954 0.24676333 0.03115738 0.03385790 0.03262390 0.03103325 0.26197939 0.26098241 0.25774253 0.25476021 0.25756725 0.25379183 0.02926918 0.02846975 0.02784482 0.03051622 0.02969639 0.02902693 0.04681769 0.07940407 0.05416672 0.04282559 0.46879378 0.50942572 0.49085903 0.46692610 0.07836146 0.06225006 0.04458234 0.05549477 0.06025723 0.03358946 0.00945500 0.00778649 0.00575185 0.00717580 0.00750894 0.00455083 0.03503132 0.02976547 0.02398338 0.02749575 0.02921387 0.02048164 0.00441123 0.00338649 0.00230662 0.00282876 0.00331057 0.00172847 0.46100068 0.43007067 0.46660871 0.41585252 0.45461734 0.53559965 0.44427018 0.41160874 0.46579589 0.52307069 0.45610501 0.42763848 0.20082501 0.19845348 0.18797401 0.19051696 0.20904682 0.18271819 0.20296278 0.19137962 0.18182854 0.20989997 0.19033713 0.18543345 0.06555334 0.06133700 0.04699794 0.04994778 0.00546414 0.00500669 0.00396932 0.00426342 0.05605008 0.00364899 0.01123372 0.00921197 0.00496184 0.00821026 0.01223543 0.00328353 0.00959948 0.00659609 0.00189431 0.01060119 0.00357263 0.00491777 0.00118795 0.00098739 0.00071405 0.00090522 0.00111591 0.00065685 0.00086428 0.00061343 0.00046779 0.00091743 0.00050876 0.00064362 -0.00302346 -0.00163424 0.00138922 0.00057110 0.00065931 0.00040151
2 changes: 1 addition & 1 deletion test/forestmodelchoice-ks-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ TEST_CASE("ModelChoice KS test")
bar.progress(i,nrun);

auto res = ModelChoice_fun(myread,statobs,opts,true);
postprobas[i] = res.post_proba;
postprobas[i] = res.post_proba[0];
}
std::cout << std::endl;
double D, pvalue;
Expand Down
14 changes: 14 additions & 0 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@ def test_modelchoice(path):
os.remove(filePath)
subprocess.run(path)

def test_modelchoice_multi(path):
    """Smoke-run model choice against the multi-target file statobsRF2.txt.

    Removes every ``modelchoice_out.*`` file from the current directory, then
    launches the executable at ``path`` with ``-b statobsRF2.txt``.

    NOTE(review): the exit status of the run is discarded, so a failing
    executable does not fail this test; confirm that is intended.
    """
    stale_outputs = glob.glob('modelchoice_out.*')
    for stale in stale_outputs:
        os.remove(stale)
    command = [path, "-b", "statobsRF2.txt"]
    subprocess.call(command)

def test_estimparam(path):
    """Smoke-run the basic parameter estimation example.

    Removes every ``estimparam_out.*`` file from the current directory, then
    launches the executable at ``path`` with
    ``--parameter ra --chosenscen 3 --noob 50``.

    NOTE(review): the exit status of the run is discarded, so a failing
    executable does not fail this test; confirm that is intended.
    """
    stale_outputs = glob.glob('estimparam_out.*')
    for stale in stale_outputs:
        os.remove(stale)
    command = [path, "--parameter", "ra", "--chosenscen", "3", "--noob", "50"]
    subprocess.call(command)

def test_estimparam_multi(path):
    """Smoke-run parameter estimation against the multi-target file statobsRF2.txt.

    Removes every ``estimparam_out.*`` file from the current directory, then
    launches the executable at ``path`` with
    ``-b statobsRF2.txt --parameter ra --chosenscen 3 --noob 50``.

    NOTE(review): the exit status of the run is discarded, so a failing
    executable does not fail this test; confirm that is intended.
    """
    stale_outputs = glob.glob('estimparam_out.*')
    for stale in stale_outputs:
        os.remove(stale)
    command = [
        path,
        "-b", "statobsRF2.txt",
        "--parameter", "ra",
        "--chosenscen", "3",
        "--noob", "50",
    ]
    subprocess.call(command)

def test_parallel(path):
"""Check multithreaded performance
"""
Expand Down

0 comments on commit 9b4adff

Please sign in to comment.