Skip to content

Commit

Permalink
[pycaffe] re-expose Blob
Browse files Browse the repository at this point in the history
  • Loading branch information
longjon committed Feb 17, 2015
1 parent 74a584c commit 8a510f7
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace caffe {

// For Python, for now, we'll just always use float as the type.
typedef float Dtype;
const int NPY_DTYPE = NPY_FLOAT32;

void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
Expand Down Expand Up @@ -118,6 +119,44 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
PyArray_DIMS(data_arr)[0]);
}

struct NdarrayConverterGenerator {
template <typename T> struct apply;
};

template <>
struct NdarrayConverterGenerator::apply<Dtype*> {
struct type {
PyObject* operator() (Dtype* data) const {
// Just store the data pointer, and add the shape information in postcall.
return PyArray_SimpleNewFromData(0, NULL, NPY_DTYPE, data);
}
const PyTypeObject* get_pytype() {
return &PyArray_Type;
}
};
};

struct NdarrayCallPolicies : public bp::default_call_policies {
typedef NdarrayConverterGenerator result_converter;
PyObject* postcall(PyObject* pyargs, PyObject* result) {
bp::object pyblob = bp::extract<bp::tuple>(pyargs)()[0];
shared_ptr<Blob<Dtype> > blob =
bp::extract<shared_ptr<Blob<Dtype> > >(pyblob);
// Free the temporary pointer-holding array, and construct a new one with
// the shape information from the blob.
void* data = PyArray_DATA(reinterpret_cast<PyArrayObject*>(result));
Py_DECREF(result);
npy_intp dims[] = {blob->num(), blob->channels(),
blob->height(), blob->width()};
PyObject* arr_obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, data);
// SetBaseObject steals a ref, so we need to INCREF.
Py_INCREF(pyblob.ptr());
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr_obj),
pyblob.ptr());
return arr_obj;
}
};

BOOST_PYTHON_MODULE(_caffe) {
// below, we prepend an underscore to methods that will be replaced
// in Python
Expand Down Expand Up @@ -155,6 +194,19 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
.def("save", &Net_Save);

bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
"Blob", bp::no_init)
.add_property("num", &Blob<Dtype>::num)
.add_property("channels", &Blob<Dtype>::channels)
.add_property("height", &Blob<Dtype>::height)
.add_property("width", &Blob<Dtype>::width)
.add_property("count", &Blob<Dtype>::count)
.def("reshape", &Blob<Dtype>::Reshape)
.add_property("data", bp::make_function(&Blob<Dtype>::mutable_cpu_data,
NdarrayCallPolicies()))
.add_property("diff", bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
NdarrayCallPolicies()));

// vector wrappers for all the vector types we use
bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
Expand Down

0 comments on commit 8a510f7

Please sign in to comment.