-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added bindings for greedy module (#21)
* added bindings for greedy module * Update Greedy bindings and add tests * Update bindings and test for getNearestROM --------- Co-authored-by: Cole Kendrick <kendrick6@llnl.gov>
- Loading branch information
1 parent
78e9756
commit 35dca81
Showing
7 changed files
with
472 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# To add pure python routines to this module, | ||
# either define/import the python routine in this file. | ||
# This will combine both c++ bindings/pure python routines into this module. | ||
|
||
# For other c++ binding modules, change the module name accordingly. | ||
from _pylibROM.algo.greedy import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/operators.h> | ||
#include <pybind11/stl.h> | ||
#include "algo/greedy/GreedyCustomSampler.h" | ||
#include "linalg/Vector.h" | ||
|
||
namespace py = pybind11; | ||
using namespace CAROM; | ||
using namespace std; | ||
|
||
void init_GreedyCustomSampler(pybind11::module_ &m) { | ||
py::class_<GreedyCustomSampler, GreedySampler>(m, "GreedyCustomSampler") | ||
.def(py::init<std::vector<CAROM::Vector>, bool, double, double, double, int, int, std::string, std::string, bool, int, bool>(), | ||
py::arg("parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false) | ||
.def(py::init<std::vector<double>, bool, double, double, double, int, int, std::string, std::string, bool, int, bool>(), | ||
py::arg("parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false) | ||
.def(py::init<std::string, std::string>(), | ||
py::arg("base_file_name"), | ||
py::arg("output_log_path") = "") | ||
.def("__del__", [](GreedyCustomSampler& self){ self.~GreedyCustomSampler(); }); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/operators.h> | ||
#include <pybind11/stl.h> | ||
#include "algo/greedy/GreedyRandomSampler.h" | ||
#include "linalg/Vector.h" | ||
|
||
namespace py = pybind11; | ||
using namespace CAROM; | ||
using namespace std; | ||
|
||
void init_GreedyRandomSampler(py::module &m) { | ||
py::class_<GreedyRandomSampler, GreedySampler>(m, "GreedyRandomSampler") | ||
.def(py::init<CAROM::Vector, CAROM::Vector, int, bool, double, double,double, int, int, bool, std::string, std::string, bool, int, bool>(), | ||
py::arg("param_space_min"), | ||
py::arg("param_space_max"), | ||
py::arg("num_parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("use_latin_hypercube"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false | ||
) | ||
.def(py::init<double, double, int, bool, double, double,double, int, int, bool, std::string, std::string, bool, int, bool>(), | ||
py::arg("param_space_min"), | ||
py::arg("param_space_max"), | ||
py::arg("num_parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("use_latin_hypercube"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false | ||
) | ||
.def(py::init<std::string, std::string>(), | ||
py::arg("base_file_name"), | ||
py::arg("output_log_path") = "" | ||
) | ||
.def("save", &GreedyRandomSampler::save, py::arg("base_file_name")) | ||
.def("__del__", [](GreedyRandomSampler& self){ self.~GreedyRandomSampler(); }); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/operators.h> | ||
#include <pybind11/stl.h> | ||
#include "algo/greedy/GreedySampler.h" | ||
#include "linalg/Vector.h" | ||
|
||
namespace py = pybind11; | ||
using namespace CAROM; | ||
using namespace std; | ||
|
||
|
||
class PyGreedySampler : public GreedySampler { | ||
public: | ||
using GreedySampler::GreedySampler; | ||
|
||
void save(std::string base_file_name) override { | ||
PYBIND11_OVERRIDE(void,GreedySampler,save,base_file_name ); | ||
} | ||
protected: | ||
void constructParameterPoints() override { | ||
PYBIND11_OVERRIDE_PURE(void, GreedySampler, constructParameterPoints,); | ||
} | ||
void getNextParameterPointAfterConvergenceFailure() override { | ||
PYBIND11_OVERRIDE_PURE(void, GreedySampler, getNextParameterPointAfterConvergenceFailure,); | ||
} | ||
}; | ||
|
||
void init_GreedySampler(pybind11::module_ &m) { | ||
py::class_<GreedyErrorIndicatorPoint>(m, "GreedyErrorIndicatorPoint") | ||
.def_property_readonly("point", [](GreedyErrorIndicatorPoint &self) { | ||
return self.point.get(); | ||
}) | ||
.def_property_readonly("localROM", [](GreedyErrorIndicatorPoint &self) { | ||
return self.localROM.get(); | ||
}); | ||
|
||
py::class_<GreedySampler,PyGreedySampler>(m, "GreedySampler") | ||
.def(py::init<std::vector<Vector>, bool, double, double, double, int, int, std::string, std::string, bool, int, bool>(), | ||
py::arg("parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false) | ||
.def(py::init<std::vector<double>,bool, double, double, double, int, int, std::string, std::string, bool, int, bool>(), | ||
py::arg("parameter_points"), | ||
py::arg("check_local_rom"), | ||
py::arg("relative_error_tolerance"), | ||
py::arg("alpha"), | ||
py::arg("max_clamp"), | ||
py::arg("subset_size"), | ||
py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", | ||
py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, | ||
py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false) | ||
.def(py::init<Vector, Vector, int, bool, double, double, double, int, int,std::string, std::string, bool, int, bool>(), | ||
py::arg("param_space_min"), py::arg("param_space_max"), py::arg("num_parameter_points"), | ||
py::arg("check_local_rom"), py::arg("relative_error_tolerance"), py::arg("alpha"), | ||
py::arg("max_clamp"), py::arg("subset_size"), py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false | ||
) | ||
.def(py::init<double, double, int, bool, double, double, double, int, int,std::string, std::string, bool, int, bool>(), | ||
py::arg("param_space_min"), py::arg("param_space_max"), py::arg("num_parameter_points"), | ||
py::arg("check_local_rom"), py::arg("relative_error_tolerance"), py::arg("alpha"), | ||
py::arg("max_clamp"), py::arg("subset_size"), py::arg("convergence_subset_size"), | ||
py::arg("output_log_path") = "", py::arg("warm_start_file_name") = "", | ||
py::arg("use_centroid") = true, py::arg("random_seed") = 1, | ||
py::arg("debug_algorithm") = false | ||
) | ||
.def(py::init<std::string, std::string>(), py::arg("base_file_name"), py::arg("output_log_path") = "") | ||
.def("getNextParameterPoint", [](GreedySampler& self) -> std::unique_ptr<Vector> { | ||
std::shared_ptr<Vector> result = self.getNextParameterPoint(); | ||
return std::make_unique<Vector>(*(result.get())); | ||
}) | ||
.def("getNextPointRequiringRelativeError", [](GreedySampler& self) -> GreedyErrorIndicatorPoint { | ||
// Create a deepcopy of the struct, otherwise it will get freed twice | ||
GreedyErrorIndicatorPoint point = self.getNextPointRequiringRelativeError(); | ||
Vector *t_pt = nullptr; | ||
Vector *t_lROM = nullptr; | ||
|
||
if (point.point) | ||
{ | ||
t_pt = new Vector(*(point.point)); | ||
} | ||
|
||
if (point.localROM) | ||
{ | ||
t_lROM = new Vector(*(point.localROM)); | ||
} | ||
|
||
return createGreedyErrorIndicatorPoint(t_pt, t_lROM); | ||
}, py::return_value_policy::reference) | ||
.def("getNextPointRequiringErrorIndicator", [](GreedySampler& self) -> GreedyErrorIndicatorPoint { | ||
// Create a deepcopy of the struct, otherwise it will get freed twice | ||
GreedyErrorIndicatorPoint point = self.getNextPointRequiringErrorIndicator(); | ||
|
||
Vector *t_pt = nullptr; | ||
Vector *t_lROM = nullptr; | ||
|
||
if (point.point) | ||
{ | ||
t_pt = new Vector(*(point.point)); | ||
} | ||
|
||
if (point.localROM) | ||
{ | ||
t_lROM = new Vector(*(point.localROM)); | ||
} | ||
|
||
return createGreedyErrorIndicatorPoint(t_pt, t_lROM); | ||
}, py::return_value_policy::reference) | ||
.def("setPointRelativeError", (void (GreedySampler::*) (double))&GreedySampler::setPointRelativeError) | ||
.def("setPointErrorIndicator", (void (GreedySampler::*) (double,int)) &GreedySampler::setPointErrorIndicator) | ||
.def("getNearestNonSampledPoint", (int (GreedySampler::*) (CAROM::Vector)) &GreedySampler::getNearestNonSampledPoint) | ||
.def("getNearestROM", [](GreedySampler& self, Vector point) -> std::unique_ptr<Vector> { | ||
std::shared_ptr<Vector> result = self.getNearestROM(point); | ||
if (!result) | ||
{ | ||
return nullptr; | ||
} | ||
return std::make_unique<Vector>(*(result.get())); | ||
}) | ||
.def("getParameterPointDomain", &GreedySampler::getParameterPointDomain) | ||
.def("getSampledParameterPoints", &GreedySampler::getSampledParameterPoints) | ||
.def("save", &GreedySampler::save) | ||
.def("__del__", [](GreedySampler& self){ self.~GreedySampler(); }) | ||
.def("isComplete", &GreedySampler::isComplete); | ||
|
||
m.def("createGreedyErrorIndicatorPoint", [](Vector* point, Vector* localROM) { | ||
return createGreedyErrorIndicatorPoint(point, localROM); | ||
}); | ||
m.def("createGreedyErrorIndicatorPoint", [](Vector* point, std::shared_ptr<Vector>& localROM) { | ||
return createGreedyErrorIndicatorPoint(point, localROM); | ||
}); | ||
m.def("getNearestPoint", [](std::vector<Vector>& paramPoints,Vector point) { | ||
return getNearestPoint(paramPoints, point); | ||
}); | ||
m.def("getNearestPoint", [](std::vector<double>& paramPoints, double point) { | ||
return getNearestPoint(paramPoints, point); | ||
}); | ||
m.def("getNearestPointIndex", [](std::vector<Vector> paramPoints, Vector point) { | ||
return getNearestPointIndex(paramPoints, point); | ||
}); | ||
m.def("getNearestPointIndex", [](std::vector<double> paramPoints, double point) { | ||
return getNearestPointIndex(paramPoints, point); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.