diff --git a/python/caffe/draw.py b/python/caffe/draw.py index e4fd7aacce7..9eecf6d7b46 100644 --- a/python/caffe/draw.py +++ b/python/caffe/draw.py @@ -104,11 +104,11 @@ def get_layer_label(layer, rankdir): pooling_types_dict[layer.pooling_param.pool], layer.type, separator, - layer.pooling_param.kernel_size[0] if len(layer.pooling_param.kernel_size._values) else 1, + layer.pooling_param.kernel_size, separator, - layer.pooling_param.stride[0] if len(layer.pooling_param.stride._values) else 1, + layer.pooling_param.stride, separator, - layer.pooling_param.pad[0] if len(layer.pooling_param.pad._values) else 0) + layer.pooling_param.pad) else: node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type) return node_label diff --git a/python/caffe/test/test_draw.py b/python/caffe/test/test_draw.py new file mode 100644 index 00000000000..1634145ee9d --- /dev/null +++ b/python/caffe/test/test_draw.py @@ -0,0 +1,33 @@ +import os +import unittest + +from google import protobuf + +import caffe.draw +from caffe.proto import caffe_pb2 + +def getFilenames(): + """Yields files in the source tree which are Net prototxts.""" + result = [] + + root_dir = os.path.abspath(os.path.join( + os.path.dirname(__file__), '..', '..', '..')) + assert os.path.exists(root_dir) + + for dirname in ('models', 'examples'): + dirname = os.path.join(root_dir, dirname) + assert os.path.exists(dirname) + for cwd, _, filenames in os.walk(dirname): + for filename in filenames: + filename = os.path.join(cwd, filename) + if filename.endswith('.prototxt') and 'solver' not in filename: + yield os.path.join(dirname, filename) + + +class TestDraw(unittest.TestCase): + def test_draw_net(self): + for filename in getFilenames(): + net = caffe_pb2.NetParameter() + with open(filename) as infile: + protobuf.text_format.Merge(infile.read(), net) + caffe.draw.draw_net(net, 'LR') diff --git a/scripts/travis/install-deps.sh b/scripts/travis/install-deps.sh index 1593ed8b59a..2fa2a74a486 100755 --- a/scripts/travis/install-deps.sh +++ b/scripts/travis/install-deps.sh @@ -8,6 +8,7 @@ source $BASEDIR/defaults.sh apt-get -y update apt-get install -y --no-install-recommends \ build-essential \ + graphviz \ libboost-filesystem-dev \ libboost-python-dev \ libboost-system-dev \ @@ -31,6 +32,7 @@ if ! $WITH_PYTHON3 ; then python-dev \ python-numpy \ python-protobuf \ + python-pydot \ python-skimage else # Python3 diff --git a/scripts/travis/install-python-deps.sh b/scripts/travis/install-python-deps.sh index eeec302791f..910d35a93be 100755 --- a/scripts/travis/install-python-deps.sh +++ b/scripts/travis/install-python-deps.sh @@ -11,4 +11,5 @@ if ! $WITH_PYTHON3 ; then else # Python3 pip install --pre protobuf==3.0.0b3 + pip install pydot fi