Skip to content

Commit

Permalink
Merge pull request numenta#490 from htm-community/finish-cereal
Browse files Browse the repository at this point in the history
Finish Cereal serialization
  • Loading branch information
dkeeney authored Jun 8, 2019
2 parents 5b0b4a9 + 3d204cc commit 8234bdd
Show file tree
Hide file tree
Showing 71 changed files with 563 additions and 2,078 deletions.
4 changes: 2 additions & 2 deletions bindings/py/cpp_src/bindings/algorithms/py_TemporalMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Resets sequence state of the TM.)");
py_HTM.def("getActiveCells", [](const HTM_t& self)
{
auto dims = self.getColumnDimensions();
dims.push_back( self.getCellsPerColumn() );
dims.push_back( static_cast<UInt32>(self.getCellsPerColumn()) );
SDR *cells = new SDR( dims );
self.getActiveCells(*cells);
return cells;
Expand Down Expand Up @@ -274,7 +274,7 @@ R"()");
py_HTM.def("getWinnerCells", [](const HTM_t& self)
{
auto dims = self.getColumnDimensions();
dims.push_back( self.getCellsPerColumn() );
dims.push_back( static_cast<UInt32>(self.getCellsPerColumn()) );
SDR *winnerCells = new SDR( dims );
self.getWinnerCells(*winnerCells);
return winnerCells;
Expand Down
11 changes: 4 additions & 7 deletions bindings/py/cpp_src/bindings/engine/py_Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,10 @@ namespace nupic_ext

py_Network.def("initialize", &nupic::Network::initialize);

py_Network.def("addRegionFromBundle", &nupic::Network::addRegionFromBundle
, "A function to load a serialized region into a Network framework."
, py::arg("name")
, py::arg("nodeType")
, py::arg("dimensions")
, py::arg("filename")
, py::arg("label") = "");
py_Network.def("save", &nupic::Network::save)
.def("load", &nupic::Network::load)
.def("saveToFile", &nupic::Network::saveToFile)
.def("loadFromFile", &nupic::Network::loadFromFile);

py_Network.def("link", &nupic::Network::link
, "Defines a link between regions"
Expand Down
29 changes: 26 additions & 3 deletions bindings/py/cpp_src/bindings/math/py_Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,32 @@ namespace nupic_ext {
//////////////////
// serialization
/////////////////
Random.def("saveToFile", [](Random_t& self, const std::string& name, int fmt) {
nupic::SerializableFormat fmt1;
switch(fmt) {
case 0: fmt1 = nupic::SerializableFormat::BINARY; break;
case 1: fmt1 = nupic::SerializableFormat::PORTABLE; break;
case 2: fmt1 = nupic::SerializableFormat::JSON; break;
case 3: fmt1 = nupic::SerializableFormat::XML; break;
default: NTA_THROW << "unknown serialization format.";
}
self.saveToFile(name, fmt1);
}, "serialize to a File, using BINARY=0, PORTABLE=1, JSON=2, or XML=3 format.",
py::arg("name"), py::arg("fmt") = 0);

Random.def("loadFromFile", [](Random_t& self, const std::string& name, int fmt) {
nupic::SerializableFormat fmt1;
switch(fmt) {
case 0: fmt1 = nupic::SerializableFormat::BINARY; break;
case 1: fmt1 = nupic::SerializableFormat::PORTABLE; break;
case 2: fmt1 = nupic::SerializableFormat::JSON; break;
case 3: fmt1 = nupic::SerializableFormat::XML; break;
default: NTA_THROW << "unknown serialization format.";
}
self.loadFromFile(name, fmt1);
}, "load from a File, using BINARY, PORTABLE, JSON, or XML format.",
py::arg("name"), py::arg("fmt") = 0);

Random.def(py::pickle(
[](const Random_t& r)
{
Expand All @@ -131,9 +157,6 @@ namespace nupic_ext {
}
));

Random.def("saveToFile", &Random_t::saveToFile);
Random.def("loadFromFile", &Random_t::loadFromFile);

}

} // namespace nupic_ext
146 changes: 53 additions & 93 deletions bindings/py/cpp_src/plugin/PyBindRegion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ In this case, the C++ engine is actually calling into the Python code.
#include <nupic/engine/Output.hpp>
#include <nupic/ntypes/Array.hpp>
#include <nupic/ntypes/BasicType.hpp>
#include <nupic/ntypes/BundleIO.hpp>
#include <nupic/utils/Log.hpp>
#include <nupic/os/Path.hpp>

Expand Down Expand Up @@ -154,113 +153,84 @@ namespace py = pybind11;

}

PyBindRegion::PyBindRegion(const char* module, BundleIO& bundle, Region * region, const char* className)
PyBindRegion::PyBindRegion(const char* module, ArWrapper& wrapper, Region * region, const char* className)
: RegionImpl(region)
, module_(module)
, className_(className)

{

deserialize(bundle);
// XXX ADD CHECK TO MAKE SURE THE TYPE MATCHES!
cereal_adapter_load(wrapper);
}

PyBindRegion::~PyBindRegion()
{
}

void PyBindRegion::serialize(BundleIO& bundle)
std::string PyBindRegion::pickleSerialize() const
{
// 1. serialize main state using pickle
// 2. call class method to serialize external state

// 1. Serialize main state of the Python module
// We want this to end up in the open stream obtained from bundle.
// a. We first pickle the python into a temporary file.
// b. copy the file into our open stream.

std::ifstream src;
std::streambuf* pbuf;
std::string tmp_pickle = "pickle.tmp";
py::tuple args = py::make_tuple(tmp_pickle, "wb");
auto f = py::module::import("__builtin__").attr("file")(*args);
// We want this to end up in the open stream obtained from bundle.
// a. We first pickle the python into a temporary file.
// b. copy the file into our open stream.

std::string tmp_pickle = "pickle.tmp";
py::tuple args = py::make_tuple(tmp_pickle, "wb");
auto f = py::module::import("__builtin__").attr("file")(*args);

#if PY_MAJOR_VERSION >= 3
auto pickle = py::module::import("pickle");
auto pickle = py::module::import("pickle");
#else
auto pickle = py::module::import("cPickle");
auto pickle = py::module::import("cPickle");
#endif
args = py::make_tuple(node_, f, 2); // use type 2 protocol
pickle.attr("dump")(*args);
pickle.attr("close")();

// get the out stream
std::ostream & out = bundle.getOutputStream();

// copy the pickle into the out stream
src.open(tmp_pickle.c_str(), std::ios::binary);
pbuf = src.rdbuf();
size_t size = pbuf->pubseekoff(0, src.end, src.in);
out << "Pickle " << size << std::endl;
pbuf->pubseekpos(0, src.in);
out << pbuf;
src.close();
out << "endPickle" << std::endl;
Path::remove(tmp_pickle);
args = py::make_tuple(node_, f, 2); // use type 2 protocol
pickle.attr("dump")(*args);
pickle.attr("close")();

// copy the pickle into the out string
std::ifstream pfile(tmp_pickle.c_str(), std::ios::binary);
std::string content((std::istreambuf_iterator<char>(pfile)),
std::istreambuf_iterator<char>());
pfile.close();
Path::remove(tmp_pickle);
return content;
}
std::string PyBindRegion::extraSerialize() const
{
std::string tmp_extra = "extra.tmp";

// 2. External state
// Call the Python serializeExtraData() method to write additional data.

args = py::make_tuple(tmp_pickle);
py::tuple args = py::make_tuple(tmp_extra);
// Need to put the None result in py::Ptr to decrement the ref count
node_.attr("serializeExtraData")(*args);

// copy the extra data into the out stream
src.open(tmp_pickle.c_str(), std::ios::binary);
pbuf = src.rdbuf();
size = pbuf->pubseekoff(0, src.end, src.in);
out << "ExtraData " << size << std::endl;
pbuf->pubseekpos(0, src.in);
out << pbuf;
src.close();
Path::remove(tmp_pickle);
out << "endExtraData" << std::endl;
Path::remove(tmp_pickle);
// copy the extra data into the extra string
std::ifstream efile(tmp_extra.c_str(), std::ios::binary);
std::string extra((std::istreambuf_iterator<char>(efile)),
std::istreambuf_iterator<char>());
efile.close();
Path::remove(tmp_extra);
return extra;

}

void PyBindRegion::deserialize(BundleIO& bundle)
{
void PyBindRegion::pickleDeserialize(std::string p) {
// 1. deserialize main state using pickle
// 2. call class method to deserialize external state

std::ofstream des;
std::streambuf *pbuf;
std::string tmp_pickle = "pickle.tmp";

// get the input stream
std::string tag;
size_t size;
char buf[10000];
std::istream & in = bundle.getInputStream();
in >> tag;
NTA_CHECK(tag == "Pickle") << "Deserialize error, expecting start of Pickle";
in >> size;
in.ignore(1);
pbuf = in.rdbuf();

// write the pickle part to pickle.tmp
des.open(tmp_pickle.c_str(), std::ios::binary);
while(size > 0) {
size_t len = (size >= sizeof(buf))?sizeof(buf): size;
pbuf->sgetn(buf, len);
des.write(buf, len);
size -= sizeof(buf);
}
des.close();
in >> tag;
NTA_CHECK(tag == "endPickle") << "Deserialize error, expected 'endPickle'\n";
in.ignore(1);
std::ofstream des;
std::string tmp_pickle = "pickle.tmp";


std::ofstream pfile(tmp_pickle.c_str(), std::ios::binary);
pfile.write(p.c_str(), p.size());
pfile.close();


// Tell Python to un-pickle using what is now in the pickle.tmp file.
py::args args = py::make_tuple(tmp_pickle, "rb");
Expand All @@ -276,30 +246,20 @@ namespace py = pybind11;
pickle.attr("load")(*args);

pickle.attr("close")();
Path::remove(tmp_pickle);
Path::remove(tmp_pickle);
}

void PyBindRegion::extraDeserialize(std::string e) {
// 2. External state
// fetch the extraData
in >> tag;
NTA_CHECK(tag == "ExtraData") << "Deserialize error, expected start of ExtraData\n";
in >> size;
in.ignore(1);
des.open(tmp_pickle.c_str(), std::ios::binary);
while(size > 0) {
size_t len = (size >= sizeof(buf))?sizeof(buf): size;
pbuf->sgetn(buf, len);
des.write(buf, len);
size -= sizeof(buf);
}
des.close();
in >> tag;
NTA_CHECK(tag == "endExtraData") << "Deserialize error, expected 'endExtraData'\n";
in.ignore(1);
std::string tmp_extra = "extra.tmp";
std::ofstream efile(tmp_extra.c_str(), std::ios::binary);
efile.write(e.c_str(), e.size());
efile.close();

// Call the Python deSerializeExtraData() method
args = py::make_tuple(tmp_pickle);
py::tuple args = py::make_tuple(tmp_extra);
node_.attr("deSerializeExtraData")(*args);
Path::remove(tmp_pickle);
Path::remove(tmp_extra);
}

template<typename T>
Expand Down
31 changes: 22 additions & 9 deletions bindings/py/cpp_src/plugin/PyBindRegion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ namespace nupic
// Constructors
PyBindRegion() = delete;
PyBindRegion(const char* module, const ValueMap& nodeParams, Region *region, const char* className);
PyBindRegion(const char* module, BundleIO& bundle, Region* region, const char* className);
PyBindRegion(const char* module, ArWrapper& wrapper, Region *region, const char* className) : RegionImpl(region) {
// TODO:cereal complete.
}
PyBindRegion(const char* module, ArWrapper& wrapper, Region *region, const char* className);

// no copy constructor
PyBindRegion(const Region &) = delete;
Expand All @@ -69,10 +66,22 @@ namespace nupic
virtual ~PyBindRegion();


// Manual serialization methods. Current recommended method.
void serialize(BundleIO& bundle) override;
void deserialize(BundleIO& bundle) override;

CerealAdapter; // see Serializable.hpp
// FOR Cereal Serialization
template<class Archive>
void save_ar(Archive& ar) const {
std::string p = pickleSerialize();
std::string e = extraSerialize();
ar(p, e);
}
template<class Archive>
void load_ar(Archive& ar) {
std::string p;
std::string e;
ar(p, e);
pickleDeserialize(p);
extraDeserialize(e);
}

bool operator==(const RegionImpl &other) const override {
NTA_THROW << " == not implemented yet for PyBindRegion.";
Expand Down Expand Up @@ -136,7 +145,11 @@ namespace nupic

Spec nodeSpec_; // locally cached version of spec.

};
std::string pickleSerialize() const;
std::string extraSerialize() const;
void pickleDeserialize(std::string p);
void extraDeserialize(std::string e);
};



Expand Down
20 changes: 0 additions & 20 deletions bindings/py/cpp_src/plugin/RegisteredRegionImplPy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ static int python_node_count = 0;
namespace nupic
{
class Spec;
class BundleIO;
class PyRegionImpl;
class Region;
class ValueMap;
Expand Down Expand Up @@ -119,25 +118,6 @@ namespace nupic
}
}

// use PyBindRegion class to instantiate and deserialize the python class in the specified module. TODO:cereal Remove
RegionImpl* deserializeRegionImpl(BundleIO& bundle, Region *region) override
{
try {
return new PyBindRegion(module_.c_str(), bundle, region, classname_.c_str());
}
catch (const py::error_already_set& e)
{
throw Exception(__FILE__, __LINE__, e.what());
}
catch (nupic::Exception & e)
{
throw nupic::Exception(e);
}
catch (...)
{
NTA_THROW << "Something bad happed while deserializing a .py region";
}
}
// use PyBindRegion class to instantiate and deserialize the python class in the specified module.
RegionImpl* deserializeRegionImpl(ArWrapper& wrapper, Region *region) override
{
Expand Down
1 change: 0 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ set(ntypes_files
nupic/ntypes/ArrayBase.hpp
nupic/ntypes/BasicType.cpp
nupic/ntypes/BasicType.hpp
nupic/ntypes/BundleIO.hpp
nupic/ntypes/Collection.hpp
nupic/ntypes/Dimensions.hpp
nupic/ntypes/Scalar.cpp
Expand Down
Loading

0 comments on commit 8234bdd

Please sign in to comment.