Merge pull request #59 from beniz/unsupervised
Support for unsupervised models + access to inner Caffe neural net's layers
beniz committed Feb 16, 2016
2 parents 126a58a + 622a979 commit 4f6c58c
Showing 9 changed files with 812 additions and 556 deletions.
11 changes: 11 additions & 0 deletions src/apidata.h
@@ -131,6 +131,17 @@ namespace dd
_data.insert(std::pair<std::string,ad_variant_type>(key,val));
}

/**
* \brief erase key / object from data object
* @param key string unique key
*/
inline void erase(const std::string &key)
{
auto hit = _data.begin();
if ((hit=_data.find(key))!=_data.end())
_data.erase(hit);
}

/**
* \brief get value from data object
* at this stage, type of value is unknown and the typed object
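A minimal usage sketch of the new erase() helper, assuming the DeepDetect headers are on the include path and relying on the existing add()/has() accessors; the key name and value are only illustrative:

#include "apidata.h"
#include <cassert>
#include <string>

int main()
{
  dd::APIData ad;
  ad.add("extract_layer",std::string("pool5")); // illustrative key/value
  assert(ad.has("extract_layer"));
  ad.erase("extract_layer"); // new in this commit: silently ignores missing keys
  assert(!ad.has("extract_layer"));
  return 0;
}
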
109 changes: 77 additions & 32 deletions src/caffelib.cc
Expand Up @@ -1722,6 +1722,10 @@ namespace dd
Caffe::set_mode(Caffe::CPU);
#endif

std::string extract_layer;
if (ad_mllib.has("extract_layer"))
extract_layer = ad_mllib.get("extract_layer").get<std::string>();

TInputConnectorStrategy inputc(this->_inputc);
APIData cad = ad;
cad.add("model_repo",this->_mlmodel._repo);
@@ -1747,45 +1751,83 @@
}

float loss = 0.0;
std::vector<Blob<float>*> results;

try
{
results = _net->ForwardPrefilled(&loss);
}
catch (std::exception &e)
{
delete _net; // because empirical analysis reveals that the net is left in an unusable state
_net = nullptr;
throw;
}
int slot = results.size() - 1;
if (_regression)
TOutputConnectorStrategy tout;
if (extract_layer.empty())
{
if (_ntargets > 1)
slot = 1;
else slot = 0; // XXX: more in-depth testing required
std::vector<Blob<float>*> results = _net->ForwardPrefilled(&loss);
int slot = results.size() - 1;
if (_regression)
{
if (_ntargets > 1)
slot = 1;
else slot = 0; // XXX: more in-depth testing required
}
int scount = results[slot]->count();
int scperel = scount / batch_size;
int nclasses = scperel;
std::vector<APIData> vrad;
for (int j=0;j<batch_size;j++)
{
APIData rad;
rad.add("uri",inputc._ids.at(j));
rad.add("loss",loss);
std::vector<double> probs;
std::vector<std::string> cats;
for (int i=0;i<nclasses;i++)
{
probs.push_back(results[slot]->cpu_data()[j*scperel+i]);
cats.push_back(this->_mlmodel.get_hcorresp(i));
}
rad.add("probs",probs);
rad.add("cats",cats);
vrad.push_back(rad);
}
tout.add_results(vrad);
if (_regression)
{
out.add("regression",true);
out.add("nclasses",nclasses);
}
}
int scount = results[slot]->count();
int scperel = scount / batch_size;
int nclasses = scperel;
TOutputConnectorStrategy tout;
for (int j=0;j<batch_size;j++)
else // unsupervised
{
tout.add_result(inputc._ids.at(j),loss);
for (int i=0;i<nclasses;i++)
tout.add_cat(inputc._ids.at(j),results[slot]->cpu_data()[j*scperel+i],this->_mlmodel.get_hcorresp(i));
std::map<std::string,int> n_layer_names_index = _net->layer_names_index();
std::map<std::string,int>::const_iterator lit;
if ((lit=n_layer_names_index.find(extract_layer))==n_layer_names_index.end())
throw MLLibBadParamException("unknown extract layer " + extract_layer);
int li = (*lit).second;
loss = _net->ForwardFromTo(0,li);
const std::vector<std::vector<Blob<float>*>>& rresults = _net->top_vecs();
std::vector<Blob<float>*> results = rresults.at(li);
std::vector<APIData> vrad;
int slot = 0;
int scount = results[slot]->count();
int scperel = scount / batch_size;
std::vector<int> vshape = {batch_size,scperel};
results[slot]->Reshape(vshape); // reshaping into a rectangle, first side = batch size
for (int j=0;j<batch_size;j++)
{
APIData rad;
rad.add("uri",inputc._ids.at(j));
rad.add("loss",loss);
std::vector<double> vals;
int cpos = 0;
for (int c=0;c<results.at(slot)->shape(1);c++)
{
vals.push_back(results.at(slot)->cpu_data()[j*scperel+cpos]);
++cpos;
}
rad.add("vals",vals);
vrad.push_back(rad);
}
tout.add_results(vrad);
}
TOutputConnectorStrategy btout(this->_outputc);
if (_regression)
tout._best = nclasses;
tout.best_cats(ad.getobj("parameters").getobj("output"),btout);
btout.to_ad(out,_regression);
tout.finalize(ad.getobj("parameters").getobj("output"),out);
out.add("status",0);

return 0;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
void CaffeLib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::update_in_memory_net_and_solver(caffe::SolverParameter &sp,
const APIData &ad,
@@ -2043,8 +2085,11 @@ namespace dd
//debug
}
}

template class CaffeLib<ImgCaffeInputFileConn,SupervisedOutput,CaffeModel>;
template class CaffeLib<CSVCaffeInputFileConn,SupervisedOutput,CaffeModel>;
template class CaffeLib<TxtCaffeInputFileConn,SupervisedOutput,CaffeModel>;
template class CaffeLib<ImgCaffeInputFileConn,UnsupervisedOutput,CaffeModel>;
template class CaffeLib<CSVCaffeInputFileConn,UnsupervisedOutput,CaffeModel>;
template class CaffeLib<TxtCaffeInputFileConn,UnsupervisedOutput,CaffeModel>;
}
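
For reference, the layer-extraction path added above can be condensed into a standalone helper against the stock Caffe Net API. This is a sketch only: the function name, the error type and the assumption that the requested layer's first top blob holds the wanted activations are mine, not the commit's; net construction and input filling are assumed to have happened already.

#include <caffe/net.hpp>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Run the net only up to the named layer and copy out that layer's activations.
float extract_inner_layer(caffe::Net<float> *net,
                          const std::string &extract_layer,
                          std::vector<float> &vals)
{
  const std::map<std::string,int> &names = net->layer_names_index();
  std::map<std::string,int>::const_iterator lit = names.find(extract_layer);
  if (lit == names.end())
    throw std::runtime_error("unknown extract layer " + extract_layer);
  int li = lit->second;
  float loss = net->ForwardFromTo(0,li); // forward pass stops at the requested layer
  caffe::Blob<float> *top = net->top_vecs().at(li).at(0); // first top blob of that layer
  const float *data = top->cpu_data();
  vals.assign(data,data + top->count());
  return loss;
}
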
4 changes: 2 additions & 2 deletions src/caffelib.h
@@ -130,7 +130,7 @@ namespace dd
//TODO: status ?

/*- local functions -*/
/**
/**
* \brief test net
* @param ad root data object
* @param inputc input connector
@@ -203,7 +203,7 @@ namespace dd
int _ntargets = 0; /**< number of classification or regression targets. */
std::mutex _net_mutex; /**< mutex around net, e.g. no concurrent predict calls as net is not re-instantiated. Use batches instead. */
};

}

#endif
42 changes: 32 additions & 10 deletions src/jsonapi.cc
@@ -271,6 +271,7 @@ namespace dd
// optional parameters.
if (d.HasMember("type"))
type = d["type"].GetString();
else type = "supervised"; // default
if (d.HasMember("description"))
description = d["description"].GetString();

@@ -294,15 +295,35 @@
if (mllib == "caffe")
{
CaffeModel cmodel(ad_model);
if (input == "image")
add_service(sname,std::move(MLService<CaffeLib,ImgCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "csv")
add_service(sname,std::move(MLService<CaffeLib,CSVCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "txt")
add_service(sname,std::move(MLService<CaffeLib,TxtCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else return dd_input_connector_not_found_1004();
if (JsonAPI::store_json_blob(cmodel._repo,jstr)) // store successful call json blob
LOG(ERROR) << "couldn't write " << JsonAPI::_json_blob_fname << " file in model repository " << cmodel._repo << std::endl;
if (type == "supervised")
{
if (input == "image")
add_service(sname,std::move(MLService<CaffeLib,ImgCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "csv")
add_service(sname,std::move(MLService<CaffeLib,CSVCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "txt")
add_service(sname,std::move(MLService<CaffeLib,TxtCaffeInputFileConn,SupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else return dd_input_connector_not_found_1004();
if (JsonAPI::store_json_blob(cmodel._repo,jstr)) // store successful call json blob
LOG(ERROR) << "couldn't write " << JsonAPI::_json_blob_fname << " file in model repository " << cmodel._repo << std::endl;
}
else if (type == "unsupervised")
{
if (input == "image")
add_service(sname,std::move(MLService<CaffeLib,ImgCaffeInputFileConn,UnsupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "csv")
add_service(sname,std::move(MLService<CaffeLib,CSVCaffeInputFileConn,UnsupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else if (input == "txt")
add_service(sname,std::move(MLService<CaffeLib,TxtCaffeInputFileConn,UnsupervisedOutput,CaffeModel>(sname,cmodel,description)),ad);
else return dd_input_connector_not_found_1004();
if (JsonAPI::store_json_blob(cmodel._repo,jstr)) // store successful call json blob
LOG(ERROR) << "couldn't write " << JsonAPI::_json_blob_fname << " file in model repository " << cmodel._repo << std::endl;
}
else
{
// unknown service type
return dd_service_bad_request_1006();
}
}
else
{
@@ -468,7 +489,8 @@
jhead.AddMember("service",d["service"],jpred.GetAllocator());
jpred.AddMember("head",jhead,jpred.GetAllocator());
JVal jbody(rapidjson::kObjectType);
jbody.AddMember("predictions",jout["predictions"],jpred.GetAllocator());
if (jout.HasMember("predictions"))
jbody.AddMember("predictions",jout["predictions"],jpred.GetAllocator());
jpred.AddMember("body",jbody,jpred.GetAllocator());
if (ad_data.getobj("parameters").getobj("output").has("template"))
{
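
Client-side, the new code paths are reached through the usual JSON API. Below is a rough sketch of the two request bodies involved; the service name, model repository path, layer name and data are placeholders, and the exact parameter layout should be checked against the API documentation rather than taken from this sketch.

#include <string>

// Hypothetical payload for service creation (PUT /services/feats):
// "type":"unsupervised" selects the UnsupervisedOutput connector added above.
const std::string create_unsupervised = R"({
  "mllib": "caffe",
  "description": "feature extraction service",
  "type": "unsupervised",
  "parameters": { "input": { "connector": "image" }, "mllib": {}, "output": {} },
  "model": { "repository": "/path/to/model" }
})";

// Hypothetical predict payload (POST /predict): "extract_layer" under
// parameters.mllib is read in caffelib.cc and returns that layer's values.
const std::string predict_extract = R"({
  "service": "feats",
  "parameters": { "input": {}, "mllib": { "extract_layer": "pool5" }, "output": {} },
  "data": [ "/path/to/image.jpg" ]
})";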