Skip to content

Commit

Permalink
NUP-2339: Add capnproto serialization to SVM
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Jun 6, 2017
1 parent 49bf1d2 commit cf03aa5
Show file tree
Hide file tree
Showing 5 changed files with 751 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ set(src_capnp_specs_rel
nupic/proto/SparseMatrixProto.capnp
nupic/proto/SpatialPoolerProto.capnp
nupic/proto/SdrClassifier.capnp
nupic/proto/SvmProto.capnp
nupic/proto/TemporalMemoryProto.capnp
nupic/proto/TestNodeProto.capnp
nupic/proto/VectorFileSensorProto.capnp
Expand Down Expand Up @@ -525,6 +526,7 @@ add_executable(${src_executable_gtests}
test/unit/algorithms/SDRClassifierTest.cpp
test/unit/algorithms/SegmentTest.cpp
test/unit/algorithms/SpatialPoolerTest.cpp
test/unit/algorithms/SvmTest.cpp
test/unit/algorithms/TemporalMemoryTest.cpp
test/unit/encoders/ScalarEncoderTest.cpp
test/unit/engine/InputTest.cpp
Expand Down
337 changes: 337 additions & 0 deletions src/nupic/algorithms/Svm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,48 @@ void svm_parameter::load(std::istream &inStream) {
shrinking >> weight_label >> weight;
}

//------------------------------------------------------------------------------
void svm_parameter::read(SvmParameterProto::Reader &proto) {
kernel = proto.getKernel();
probability = proto.getProbability();
gamma = proto.getGamma();
C = proto.getC();
eps = proto.getEps();
cache_size = proto.getCacheSize();
shrinking = proto.getShrinking();

auto weightList = proto.getWeight();
weight.resize(weightList.size());
std::copy(weightList.begin(), weightList.end(), weight.begin());

auto labelList = proto.getWeightLabel();
weight_label.resize(labelList.size());
std::copy(labelList.begin(), labelList.begin(), weight_label.begin());
}

//------------------------------------------------------------------------------
void svm_parameter::write(SvmParameterProto::Builder &proto) const {
proto.setKernel(kernel);
proto.setProbability(probability);
proto.setGamma(gamma);
proto.setC(C);
proto.setEps(eps);
proto.setCacheSize(cache_size);
proto.setShrinking(shrinking);

size_t size = weight.size();
auto weightList = proto.initWeight(size);
for (size_t i = 0; i < size; i++) {
weightList.set(i, weight[i]);
}

size = weight_label.size();
auto labelList = proto.initWeightLabel(size);
for (size_t i = 0; i < size; i++) {
labelList.set(i, weight_label[i]);
}
}

//------------------------------------------------------------------------------
int svm_problem::persistent_size() const {
stringstream b;
Expand Down Expand Up @@ -146,6 +188,47 @@ void svm_problem::load(std::istream &inStream) {
}
}

//------------------------------------------------------------------------------
void svm_problem::read(SvmProblemProto::Reader &proto) {
recover_ = proto.getRecover();
n_dims_ = proto.getNDims();

auto yList = proto.getY();
y_.resize(yList.size());
std::copy(yList.begin(), yList.end(), y_.begin());

for (auto &elem : x_)
delete[] elem;

x_.clear();
for (auto list : proto.getX()) {
float *values = new float[list.size()];
std::copy(list.begin(), list.end(), values);
x_.push_back(values);
}
}

//------------------------------------------------------------------------------
void svm_problem::write(SvmProblemProto::Builder &proto) const {
proto.setRecover(recover_);
proto.setNDims(n_dims_);

size_t size = y_.size();
auto yList = proto.initY(size);
for (size_t i = 0; i < size; i++) {
yList.set(i, y_[i]);
}

size = x_.size();
auto xList = proto.initX(size);
for (size_t i = 0; i < size; i++) {
auto dims = xList.init(i, n_dims_);
for (int j = 0; j < n_dims_; j++) {
dims.set(j, x_[i][j]);
}
}
}

//------------------------------------------------------------------------------
int svm_problem01::persistent_size() const {
stringstream b;
Expand Down Expand Up @@ -196,6 +279,59 @@ void svm_problem01::load(std::istream &inStream) {
}
}

//------------------------------------------------------------------------------
void svm_problem01::read(SvmProblem01Proto::Reader &proto) {
recover_ = proto.getRecover();
n_dims_ = proto.getNDims();
threshold_ = proto.getThreshold();

auto yList = proto.getY();
y_.resize(yList.size());
std::copy(yList.begin(), yList.end(), y_.begin());

auto nnzList = proto.getNnz();
nnz_.resize(nnzList.size());
std::copy(nnzList.begin(), nnzList.end(), nnz_.begin());

for (auto &elem : x_)
delete[] elem;

x_.clear();
for (auto list : proto.getX()) {
int *values = new int[list.size()];
std::copy(list.begin(), list.end(), values);
x_.push_back(values);
}
}

//------------------------------------------------------------------------------
void svm_problem01::write(SvmProblem01Proto::Builder &proto) const {
proto.setRecover(recover_);
proto.setNDims(n_dims_);
proto.setThreshold(threshold_);

size_t size = y_.size();
auto yList = proto.initY(size);
for (size_t i = 0; i < size; i++) {
yList.set(i, y_[i]);
}

size = nnz_.size();
auto nnzList = proto.initNnz(size);
for (size_t i = 0; i < size; i++) {
nnzList.set(i, nnz_[i]);
}

size = x_.size();
auto xList = proto.initX(size);
for (size_t i = 0; i < size; i++) {
auto dims = xList.init(i, n_dims_);
for (int j = 0; j < n_dims_; j++) {
dims.set(j, x_[i][j]);
}
}
}

//------------------------------------------------------------------------------
svm_model::~svm_model() {
// in all cases, ownership of the mem for the sv is with svm_model
Expand Down Expand Up @@ -377,6 +513,207 @@ void svm_model::load(std::istream &inStream) {
}

//------------------------------------------------------------------------------
void svm_model::read(SvmModelProto::Reader &proto) {
n_dims_ = proto.getNDims();

if (sv_mem == nullptr) {
for (auto &elem : sv)
delete[] elem;
} else {
delete[] sv_mem;
sv_mem = nullptr;
}
sv.clear();

for (auto list : proto.getSv()) {
float *values = new float[list.size()];
std::copy(list.begin(), list.end(), values);
sv.push_back(values);
}

for (auto &elem : sv_coef)
delete[] elem;
sv_coef.clear();

for (auto list : proto.getSvCoef()) {
float *values = new float[list.size()];
std::copy(list.begin(), list.end(), values);
sv_coef.push_back(values);
}

auto wList = proto.getW();
size_t size = wList.size();
w.resize(size);
for (size_t i = 0; i < size; i++) {
auto values = wList[i];
w[i].resize(values.size());
std::copy(values.begin(), values.end(), w[i].begin());
}

auto rhoList = proto.getRho();
rho.resize(rhoList.size());
std::copy(rhoList.begin(), rhoList.end(), rho.begin());

auto probAList = proto.getProbA();
probA.resize(probAList.size());
std::copy(probAList.begin(), probAList.end(), probA.begin());

auto probBList = proto.getProbB();
probB.resize(probBList.size());
std::copy(probBList.begin(), probBList.end(), probB.begin());

auto labelList = proto.getLabel();
label.resize(labelList.size());
std::copy(labelList.begin(), labelList.end(), label.begin());

auto nsvList = proto.getNSv();
n_sv.resize(nsvList.size());
std::copy(nsvList.begin(), nsvList.end(), n_sv.begin());
}

//------------------------------------------------------------------------------
void svm_model::write(SvmModelProto::Builder &proto) const {

proto.setNDims(n_dims_);

size_t size = sv.size();
auto svList = proto.initSv(size);
for (size_t i = 0; i < size; i++) {
auto dims = svList.init(i, n_dims_);
for (int j = 0; j < n_dims_; j++) {
dims.set(j, sv[i][j]);
}
}

size = sv_coef.size();
auto svCoefList = proto.initSvCoef(size);
for (size_t i = 0; i < size; i++) {
auto dims = svCoefList.init(i, n_dims_);
for (int j = 0; j < n_dims_; j++) {
dims.set(j, sv_coef[i][j]);
}
}

size = rho.size();
auto rhoList = proto.initRho(size);
for (size_t i = 0; i < size; i++) {
rhoList.set(i, rho[i]);
}

size = label.size();
auto labelList = proto.initLabel(size);
for (size_t i = 0; i < size; i++) {
labelList.set(i, label[i]);
}

size = n_sv.size();
auto nsvList = proto.initNSv(size);
for (size_t i = 0; i < size; i++) {
nsvList.set(i, n_sv[i]);
}

size = probA.size();
auto probAList = proto.initProbA(size);
for (size_t i = 0; i < size; i++) {
probAList.set(i, probA[i]);
}

size = probB.size();
auto probBList = proto.initProbB(size);
for (size_t i = 0; i < size; i++) {
probBList.set(i, probB[i]);
}

size = w.size();
auto wList = proto.initW(size);
for (size_t i = 0; i < size; i++) {
size_t len = w[i].size();
auto dims = wList.init(i, len);
for (size_t j = 0; j < len; j++) {
dims.set(j, w[i][j]);
}
}
}

//------------------------------------------------------------------------------
void svm_dense::write(SvmDenseProto::Builder &proto) const {
auto paramProto = proto.getParam();
svm_.param_.write(paramProto);

if (svm_.model_) {
auto modelProto = proto.getModel();
svm_.model_->write(modelProto);
}
if (svm_.problem_) {
auto problemProto = proto.getProblem();
svm_.problem_->write(problemProto);
}
}

//------------------------------------------------------------------------------
void svm_dense::read(SvmDenseProto::Reader &proto) {
auto paramProto = proto.getParam();
svm_.param_.read(paramProto);

if (svm_.model_) {
delete svm_.model_;
svm_.model_ = nullptr;
}
if (proto.hasModel()) {
auto modelProto = proto.getModel();
svm_.model_ = new svm_model;
svm_.model_->read(modelProto);
}
if (svm_.problem_) {
delete svm_.problem_;
svm_.problem_ = nullptr;
}
if (proto.hasProblem()) {
auto problemProto = proto.getProblem();
svm_.problem_ = new svm_problem(1, false);
svm_.problem_->read(problemProto);
}
}

//------------------------------------------------------------------------------
void svm_01::write(Svm01Proto::Builder &proto) const {
auto paramProto = proto.getParam();
svm_.param_.write(paramProto);

if (svm_.model_) {
auto modelProto = proto.getModel();
svm_.model_->write(modelProto);
}
if (svm_.problem_) {
auto problemProto = proto.getProblem();
svm_.problem_->write(problemProto);
}
}

//------------------------------------------------------------------------------
void svm_01::read(Svm01Proto::Reader &proto) {
auto paramProto = proto.getParam();
svm_.param_.read(paramProto);

if (svm_.model_) {
delete svm_.model_;
svm_.model_ = nullptr;
}
if (proto.hasModel()) {
auto modelProto = proto.getModel();
svm_.model_ = new svm_model;
svm_.model_->read(modelProto);
}
if (svm_.problem_) {
delete svm_.problem_;
svm_.problem_ = nullptr;
}
if (proto.hasProblem()) {
auto problemProto = proto.getProblem();
svm_.problem_ = new svm_problem01(1, false);
svm_.problem_->read(problemProto);
}
}
} // namespace svm
} // namespace algorithms
} // namespace nupic
Loading

0 comments on commit cf03aa5

Please sign in to comment.