diff --git a/Makefile b/Makefile
index cb396794e37..db0f531eaa0 100644
--- a/Makefile
+++ b/Makefile
@@ -171,6 +171,7 @@ WARNINGS := -Wall -Wno-sign-compare
# Set build directories
##############################
+DISTRIBUTE_DIR ?= distribute
DISTRIBUTE_SUBDIRS := $(DISTRIBUTE_DIR)/bin $(DISTRIBUTE_DIR)/lib
DIST_ALIASES := dist
ifneq ($(strip $(DISTRIBUTE_DIR)),distribute)
@@ -232,13 +233,15 @@ endif
# libstdc++ for NVCC compatibility on OS X >= 10.9 with CUDA < 7.0
ifeq ($(OSX), 1)
CXX := /usr/bin/clang++
- CUDA_VERSION := $(shell $(CUDA_DIR)/bin/nvcc -V | grep -o 'release \d' | grep -o '\d')
- ifeq ($(shell echo $(CUDA_VERSION) \< 7.0 | bc), 1)
- CXXFLAGS += -stdlib=libstdc++
- LINKFLAGS += -stdlib=libstdc++
+ ifneq ($(CPU_ONLY), 1)
+ CUDA_VERSION := $(shell $(CUDA_DIR)/bin/nvcc -V | grep -o 'release \d' | grep -o '\d')
+ ifeq ($(shell echo $(CUDA_VERSION) \< 7.0 | bc), 1)
+ CXXFLAGS += -stdlib=libstdc++
+ LINKFLAGS += -stdlib=libstdc++
+ endif
+ # clang throws this warning for cuda headers
+ WARNINGS += -Wno-unneeded-internal-declaration
endif
- # clang throws this warning for cuda headers
- WARNINGS += -Wno-unneeded-internal-declaration
# gtest needs to use its own tuple to not conflict with clang
COMMON_FLAGS += -DGTEST_USE_OWN_TR1_TUPLE=1
# boost::thread is called boost_thread-mt to mark multithreading on OS X
diff --git a/docs/development.md b/docs/development.md
index fe54864bd35..ccb6a29701d 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -30,7 +30,7 @@ Similarly for IPython notebooks: simply include `"include_in_docs": true` in the
Other docs, such as installation guides, are written in the `docs` directory and manually linked to from the `index.md` page.
-We strive to provide provide lots of usage examples, and to document all code in docstrings.
+We strive to provide lots of usage examples, and to document all code in docstrings.
We absolutely appreciate any contribution to this effort!
### Versioning
diff --git a/docs/install_apt.md b/docs/install_apt.md
index 89bc9a00aef..75f8bec0e95 100644
--- a/docs/install_apt.md
+++ b/docs/install_apt.md
@@ -8,12 +8,24 @@ title: Installation: Ubuntu
sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libboost-all-dev libhdf5-serial-dev
+**CUDA**: Install via the NVIDIA package instead of `apt-get` to be certain of the library and driver versions.
+Install the library and latest driver separately; the driver bundled with the library is usually out-of-date.
+This can be skipped for CPU-only installation.
+
+**BLAS**: install ATLAS by `sudo apt-get install libatlas-base-dev` or install OpenBLAS or MKL for better CPU performance.
+
+**Python** (optional): if you use the default Python you will need to `sudo apt-get install` the `python-dev` package to have the Python headers for building the pycaffe interface.
+
**Remaining dependencies, 14.04**
+Everything is packaged in 14.04.
+
sudo apt-get install libgflags-dev libgoogle-glog-dev liblmdb-dev protobuf-compiler
**Remaining dependencies, 12.04**
+These dependencies need manual installation in 12.04.
+
# glog
wget https://google-glog.googlecode.com/files/glog-0.3.3.tar.gz
tar zxvf glog-0.3.3.tar.gz
@@ -28,17 +40,10 @@ title: Installation: Ubuntu
export CXXFLAGS="-fPIC" && cmake .. && make VERBOSE=1
make && make install
# lmdb
- git clone git://gitorious.org/mdb/mdb.git
+ git clone https://gitorious.org/mdb/mdb.git
cd mdb/libraries/liblmdb
make && make install
Note that glog does not compile with the most recent gflags version (2.1), so before that is resolved you will need to build with glog first.
-**CUDA**: Install via the NVIDIA package instead of `apt-get` to be certain of the library and driver versions.
-Install the library and latest driver separately; the driver bundled with the library is usually out-of-date.
-
-**BLAS**: install ATLAS by `sudo apt-get install libatlas-base-dev` or install OpenBLAS or MKL for better CPU performance.
-
-**Python** (optional): if you use the default Python you will need to `sudo apt-get install` the `python-dev` package to have the Python headers for building the pycaffe interface.
-
Continue with [compilation](installation.html#compilation).
diff --git a/docs/install_osx.md b/docs/install_osx.md
index 0373a417847..39cb02fe232 100644
--- a/docs/install_osx.md
+++ b/docs/install_osx.md
@@ -18,7 +18,7 @@ In other `ENV` settings, things may not work as expected.
brew install --fresh -vd snappy leveldb gflags glog szip lmdb
# need the homebrew science source for OpenCV and hdf5
brew tap homebrew/science
- hdf5 opencv
+ brew install hdf5 opencv
If using Anaconda Python, a modification to the OpenCV formula might be needed
Do `brew edit opencv` and change the lines that look like the two lines below to exactly the two lines below.
@@ -115,7 +115,7 @@ Then, whenever you want to update homebrew, switch back to the master branches,
# Update homebrew; hopefully this works without errors!
brew update
- # Switch back to the caffe branches with the forumlae that you modified earlier
+ # Switch back to the caffe branches with the formulae that you modified earlier
cd /usr/local
git rebase master caffe
# Fix any merge conflicts and commit to caffe branch
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index ad30d0acd55..06dc0a49ec7 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -3,28 +3,30 @@ title: Model Zoo
---
# Caffe Model Zoo
-Lots of people have used Caffe to train models of different architectures and applied to different problems, ranging from simple regression to AlexNet-alikes to Siamese networks for image similarity to speech applications.
-To lower the friction of sharing these models, we introduce the model zoo framework:
+Lots of researchers and engineers have made Caffe models for different tasks with all kinds of architectures and data.
+These models are learned and applied for problems ranging from simple regression, to large-scale visual classification, to Siamese networks for image similarity, to speech and robotics applications.
+
+To help share these models, we introduce the model zoo framework:
- A standard format for packaging Caffe model info.
-- Tools to upload/download model info to/from Github Gists, and to download trained `.caffemodel` parameters.
+- Tools to upload/download model info to/from Github Gists, and to download trained `.caffemodel` binaries.
- A central wiki page for sharing model info Gists.
-## BVLC Reference Models
+## Where to get trained models
-First of all, we provide some trained models out of the box.
+First of all, we bundle BVLC-trained models for unrestricted, out of the box use.
+
+See the [BVLC model license](#bvlc-model-license) for details.
Each one of these can be downloaded by running `scripts/download_model_binary.py <dirname>` where `<dirname>` is specified below:
-- **BVLC Reference CaffeNet** in `models/bvlc_reference_caffenet`: AlexNet trained on ILSVRC 2012, with a minor variation from the version as described in the NIPS 2012 paper. (Trained by Jeff Donahue @jeffdonahue)
-- **BVLC AlexNet** in `models/bvlc_alexnet`: AlexNet trained on ILSVRC 2012, almost exactly as described in NIPS 2012. (Trained by Evan Shelhamer @shelhamer)
-- **BVLC Reference R-CNN ILSVRC-2013** in `models/bvlc_reference_rcnn_ilsvrc13`: pure Caffe implementation of [R-CNN](https://github.com/rbgirshick/rcnn). (Trained by Ross Girshick @rbgirshick)
-- **BVLC GoogleNet** in `models/bvlc_googlenet`: GoogleNet trained on ILSVRC 2012, almost exactly as described in [GoogleNet](http://arxiv.org/abs/1409.4842). (Trained by Sergio Guadarrama @sguada)
-
+- **BVLC Reference CaffeNet** in `models/bvlc_reference_caffenet`: AlexNet trained on ILSVRC 2012, with a minor variation from the version as described in [ImageNet classification with deep convolutional neural networks](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) by Krizhevsky et al. in NIPS 2012. (Trained by Jeff Donahue @jeffdonahue)
+- **BVLC AlexNet** in `models/bvlc_alexnet`: AlexNet trained on ILSVRC 2012, almost exactly as described in [ImageNet classification with deep convolutional neural networks](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) by Krizhevsky et al. in NIPS 2012. (Trained by Evan Shelhamer @shelhamer)
+- **BVLC Reference R-CNN ILSVRC-2013** in `models/bvlc_reference_rcnn_ilsvrc13`: pure Caffe implementation of [R-CNN](https://github.com/rbgirshick/rcnn) as described by Girshick et al. in CVPR 2014. (Trained by Ross Girshick @rbgirshick)
+- **BVLC GoogLeNet** in `models/bvlc_googlenet`: GoogLeNet trained on ILSVRC 2012, almost exactly as described in [Going Deeper with Convolutions](http://arxiv.org/abs/1409.4842) by Szegedy et al. in ILSVRC 2014. (Trained by Sergio Guadarrama @sguada)
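For example, `scripts/download_model_binary.py models/bvlc_reference_caffenet` fetches the CaffeNet weights into that model directory.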
-## Community Models
-
-The publicly-editable [Caffe Model Zoo wiki](https://github.com/BVLC/caffe/wiki/Model-Zoo) catalogues user-made models.
-Refer to the model details for authorship and conditions -- please respect licenses and citations.
+**Community models** made by Caffe users are posted to a publicly editable [wiki page](https://github.com/BVLC/caffe/wiki/Model-Zoo).
+These models are subject to conditions of their respective authors such as citation and license.
+Thank you for sharing your models!
## Model info format
@@ -44,7 +46,7 @@ A caffe model is distributed as a directory containing:
Github Gist is a good format for model info distribution because it can contain multiple files, is versionable, and has in-browser syntax highlighting and markdown rendering.
-- `scripts/upload_model_to_gist.sh <dirname>`: uploads non-binary files in the model directory as a Github Gist and prints the Gist ID. If `gist_id` is already part of the `<dirname>/readme.md` frontmatter, then updates existing Gist.
+`scripts/upload_model_to_gist.sh <dirname>` uploads non-binary files in the model directory as a Github Gist and prints the Gist ID. If `gist_id` is already part of the `<dirname>/readme.md` frontmatter, then updates existing Gist.
Try doing `scripts/upload_model_to_gist.sh models/bvlc_alexnet` to test the uploading (don't forget to delete the uploaded gist afterward).
@@ -56,4 +58,13 @@ It is up to the user where to host the `.caffemodel` file.
We host our BVLC-provided models on our own server.
Dropbox also works fine (tip: make sure that `?dl=1` is appended to the end of the URL).
-- `scripts/download_model_binary.py <dirname>`: downloads the `.caffemodel` from the URL specified in the `<dirname>/readme.md` frontmatter and confirms SHA1.
+`scripts/download_model_binary.py <dirname>` downloads the `.caffemodel` from the URL specified in the `<dirname>/readme.md` frontmatter and confirms SHA1.
+
+## BVLC model license
+
+The Caffe models bundled by the BVLC are released for unrestricted use.
+
+These models are trained on data from the [ImageNet project](http://www.image-net.org/) and training data includes internet photos that may be subject to copyright.
+
+Our present understanding as researchers is that there is no restriction placed on the open release of these learned model weights, since none of the original images are distributed in whole or in part.
+To the extent that the interpretation arises that weights are derivative works of the original copyright holder and they assert such a copyright, UC Berkeley makes no representations as to what use is allowed other than to consider our present release in the spirit of fair use in the academic mission of the university to disseminate knowledge and tools as broadly as possible without restriction.
diff --git a/examples/hdf5_classification.ipynb b/examples/hdf5_classification.ipynb
index b90b79d962c..19d27372754 100644
--- a/examples/hdf5_classification.ipynb
+++ b/examples/hdf5_classification.ipynb
@@ -40,6 +40,7 @@
"import shutil\n",
"import tempfile\n",
"\n",
+ "# You may need to 'pip install scikit-learn'\n",
"import sklearn\n",
"import sklearn.datasets\n",
"import sklearn.linear_model"
@@ -1070,4 +1071,4 @@
"metadata": {}
}
]
-}
\ No newline at end of file
+}
diff --git a/examples/web_demo/app.py b/examples/web_demo/app.py
index bbeff5eb362..c667ea94c11 100644
--- a/examples/web_demo/app.py
+++ b/examples/web_demo/app.py
@@ -10,7 +10,7 @@
import tornado.httpserver
import numpy as np
import pandas as pd
-import Image
+from PIL import Image
import cStringIO as StringIO
import urllib
import exifutil
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index b1ac3a93eff..cae1c3e4ee6 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -386,8 +386,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t handle_;
- cudnnTensor4dDescriptor_t bottom_desc_;
- cudnnTensor4dDescriptor_t top_desc_;
+ cudnnTensorDescriptor_t bottom_desc_;
+ cudnnTensorDescriptor_t top_desc_;
};
#endif
diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index d9c32359b09..67255210a77 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -169,6 +169,8 @@ class HDF5DataLayer : public Layer<Dtype> {
unsigned int current_file_;
hsize_t current_row_;
std::vector<shared_ptr<Blob<Dtype> > > hdf_blobs_;
+ std::vector<unsigned int> data_permutation_;
+ std::vector<unsigned int> file_permutation_;
};
*/
/**
diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp
index 0c306fb41bf..323215134c7 100644
--- a/include/caffe/neuron_layers.hpp
+++ b/include/caffe/neuron_layers.hpp
@@ -433,8 +433,8 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
bool handles_setup_;
cudnnHandle_t handle_;
- cudnnTensor4dDescriptor_t bottom_desc_;
- cudnnTensor4dDescriptor_t top_desc_;
+ cudnnTensorDescriptor_t bottom_desc_;
+ cudnnTensorDescriptor_t top_desc_;
};
#endif
@@ -516,8 +516,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t handle_;
- cudnnTensor4dDescriptor_t bottom_desc_;
- cudnnTensor4dDescriptor_t top_desc_;
+ cudnnTensorDescriptor_t bottom_desc_;
+ cudnnTensorDescriptor_t top_desc_;
};
#endif
@@ -601,8 +601,8 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t handle_;
- cudnnTensor4dDescriptor_t bottom_desc_;
- cudnnTensor4dDescriptor_t top_desc_;
+ cudnnTensorDescriptor_t bottom_desc_;
+ cudnnTensorDescriptor_t top_desc_;
};
#endif
@@ -654,6 +654,90 @@ class ThresholdLayer : public NeuronLayer<Dtype> {
Dtype threshold_;
};
+/**
+ * @brief Parameterized Rectified Linear Unit non-linearity @f$
+ * y_i = \max(0, x_i) + a_i \min(0, x_i)
+ * @f$. The differences from ReLULayer are 1) negative slopes are
+ * learnable through backprop and 2) negative slopes can vary across
+ * channels. The number of axes of input blob should be greater than or
+ * equal to 2. The 1st axis (0-based) is seen as channels.
+ */
+template <typename Dtype>
+class PReLULayer : public NeuronLayer<Dtype> {
+ public:
+ /**
+ * @param param provides PReLUParameter prelu_param,
+ * with PReLULayer options:
+ * - filler (\b optional, FillerParameter,
+ * default {'type': constant 'value':0.25}).
+ * - channel_shared (\b optional, default false).
+ * negative slopes are shared across channels.
+ */
+ explicit PReLULayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param) {}
+
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "PReLU"; }
+
+ protected:
+ /**
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times ...) @f$
+ * the inputs @f$ x @f$
+ * @param top output Blob vector (length 1)
+ * -# @f$ (N \times C \times ...) @f$
+ * the computed outputs for each channel @f$i@f$ @f$
+ * y_i = \max(0, x_i) + a_i \min(0, x_i)
+ * @f$.
+ */
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ /**
+ * @brief Computes the error gradient w.r.t. the PReLU inputs.
+ *
+ * @param top output Blob vector (length 1), providing the error gradient with
+ * respect to the outputs
+ * -# @f$ (N \times C \times ...) @f$
+ * containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+ * with respect to computed outputs @f$ y @f$
+ * @param propagate_down see Layer::Backward.
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times ...) @f$
+ * the inputs @f$ x @f$; For each channel @f$i@f$, backward fills their
+ * diff with gradients @f$
+ * \frac{\partial E}{\partial x_i} = \left\{
+ * \begin{array}{lr}
+ * a_i \frac{\partial E}{\partial y_i} & \mathrm{if} \; x_i \le 0 \\
+ * \frac{\partial E}{\partial y_i} & \mathrm{if} \; x_i > 0
+ * \end{array} \right.
+ * @f$.
+ * If param_propagate_down_[0] is true, it fills the diff with gradients
+ * @f$
+ * \frac{\partial E}{\partial a_i} = \left\{
+ * \begin{array}{lr}
+ * \sum_{x_i} x_i \frac{\partial E}{\partial y_i} & \mathrm{if} \; x_i \le 0 \\
+ * 0 & \mathrm{if} \; x_i > 0
+ * \end{array} \right.
+ * @f$.
+ */
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ bool channel_shared_;
+ Blob<Dtype> multiplier_; // dot multiplier for backward computation of params
+ Blob<Dtype> bottom_memory_; // memory for in-place computation
+};
+
} // namespace caffe
#endif // CAFFE_NEURON_LAYERS_HPP_
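As a quick illustration of the PReLU rule documented above (a standalone editorial sketch, not part of the patch; the toy blob shape and values are invented for the example):

```cpp
// Toy PReLU forward pass: y_i = max(0, x_i) + a_i * min(0, x_i),
// one learnable slope per channel, channel index recovered as in the layer.
#include <algorithm>
#include <cstdio>

int main() {
  const int channels = 2, dim = 2;        // a 1 x 2 x 2 blob, flattened
  float x[channels * dim] = {-1.f, 2.f, -3.f, 4.f};
  float slope[channels] = {0.25f, 0.5f};  // a_i; the default filler is 0.25
  for (int i = 0; i < channels * dim; ++i) {
    int c = (i / dim) % channels;         // channel index, as in the layer code
    float y = std::max(x[i], 0.f) + slope[c] * std::min(x[i], 0.f);
    std::printf("y[%d] = %.2f\n", i, y);  // -0.25, 2.00, -1.50, 4.00
  }
  return 0;
}
```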
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
index 2564e0716ef..1b726de9564 100644
--- a/include/caffe/syncedmem.hpp
+++ b/include/caffe/syncedmem.hpp
@@ -12,7 +12,7 @@ namespace caffe {
// cudaMallocHost and cudaFree functions in order to create pinned memory.
// However, those codes rely on the existence of a cuda GPU (I don't know
// why that is a must since allocating memory should not be accessing the
-// GPU resorce, but it just creates an error as of Cuda 5.0) and will cause
+// GPU resource, but it just creates an error as of Cuda 5.0) and will cause
// problem when running on a machine without GPU. Thus, we simply define
// these two functions for safety and possible future change if the problem
// of calling cuda functions disappears in a future version.
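For context, the two wrappers this comment refers to amount to plain host allocation; a sketch from memory of the surrounding header (verify against syncedmem.hpp):

```cpp
// Host allocation with no CUDA involvement, so machines without a GPU
// never touch the CUDA runtime during memory management.
#include <cstdlib>

inline void CaffeMallocHost(void** ptr, std::size_t size) {
  *ptr = std::malloc(size);  // ordinary pageable memory, not cudaMallocHost
}

inline void CaffeFreeHost(void* ptr) {
  std::free(ptr);
}
```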
diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp
index eaed7333df8..b531dd5fa7a 100644
--- a/include/caffe/util/cudnn.hpp
+++ b/include/caffe/util/cudnn.hpp
@@ -50,41 +50,45 @@ template <typename Dtype> class dataType;
template<> class dataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
+ static float oneval, zeroval;
+ static const void *one, *zero;
};
template<> class dataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
+ static double oneval, zeroval;
+ static const void *one, *zero;
};
template <typename Dtype>
-inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {
- CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
+inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
+ CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}
template <typename Dtype>
-inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
- n, c, h, w, stride_n, stride_c, stride_h, stride_w));
+ n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}
template <typename Dtype>
-inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
setTensor4dDesc<Dtype>(desc, n, c, h, w,
- stride_n, stride_c, stride_h, stride_w);
+ stride_n, stride_c, stride_h, stride_w);
}
template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
- CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType<Dtype>::type,
+ CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
}
@@ -95,29 +99,29 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
template <typename Dtype>
inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
- cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
+ cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
- CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
+ CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}
template <typename Dtype>
-inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
+inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
- int h, int w, int stride_h, int stride_w) {
+ int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
switch (poolmethod) {
case PoolingParameter_PoolMethod_MAX:
*mode = CUDNN_POOLING_MAX;
break;
case PoolingParameter_PoolMethod_AVE:
- *mode = CUDNN_POOLING_AVERAGE;
+ *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
- CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv));
- CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w,
- stride_h, stride_w));
+ CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
+ CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
+ pad_h, pad_w, stride_h, stride_w));
}
} // namespace cudnn
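The `one`/`zero` members added above are the alpha/beta scaling factors that the cuDNN v2 API takes by pointer; their definitions are expected to live in a companion translation unit (presumably `src/caffe/util/cudnn.cpp`, not shown in this diff; `USE_CUDNN` guards omitted), along these lines:

```cpp
// Sketch of the out-of-line definitions for dataType<T>::one / ::zero:
// host-side constants handed to cuDNN v2 calls through const void*.
#include "caffe/util/cudnn.hpp"

namespace caffe {
namespace cudnn {

float dataType<float>::oneval = 1.0;
float dataType<float>::zeroval = 0.0;
const void* dataType<float>::one =
    static_cast<void*>(&dataType<float>::oneval);
const void* dataType<float>::zero =
    static_cast<void*>(&dataType<float>::zeroval);

double dataType<double>::oneval = 1.0;
double dataType<double>::zeroval = 0.0;
const void* dataType<double>::one =
    static_cast<void*>(&dataType<double>::oneval);
const void* dataType<double>::zero =
    static_cast<void*>(&dataType<double>::zeroval);

}  // namespace cudnn
}  // namespace caffe
```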
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 6cb507a5780..cd0ab8babb0 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -246,11 +246,13 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;
- vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
- cudnnTensor4dDescriptor_t bias_desc_;
+ vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
+ cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+ size_t workspaceSizeInBytes;
+ void *workspace;
};
#endif
@@ -445,7 +447,7 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t handle_;
- cudnnTensor4dDescriptor_t bottom_desc_, top_desc_;
+ cudnnTensorDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
};
diff --git a/python/classify.py b/python/classify.py
index 1e26ff48a23..4544c51b4c2 100755
--- a/python/classify.py
+++ b/python/classify.py
@@ -112,7 +112,7 @@ def main(argv):
# Load numpy array (.npy), directory glob (*.jpg), or image file.
args.input_file = os.path.expanduser(args.input_file)
if args.input_file.endswith('npy'):
- print("Loading file: %s" %s args.input_file
+ print("Loading file: %s" % args.input_file)
inputs = np.load(args.input_file)
elif os.path.isdir(args.input_file):
print("Loading folder: %s" % args.input_file)
diff --git a/python/requirements.txt b/python/requirements.txt
index 6a90fd6f5b4..7bc164a42b5 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -2,7 +2,6 @@ Cython>=0.19.2
numpy>=1.7.1
scipy>=0.13.2
scikit-image>=0.9.3
-scikit-learn>=0.14.1
matplotlib>=1.3.1
ipython>=1.1.0
h5py>=2.2.0
@@ -14,4 +13,4 @@ python-dateutil>=1.4,<2
protobuf>=2.5.0
python-gflags>=2.0
pyyaml>=3.10
-Pillow>=2.7.0
+Pillow>=2.3.0
diff --git a/scripts/travis/travis_install.sh b/scripts/travis/travis_install.sh
index 82f386cf029..0e8c37861b0 100755
--- a/scripts/travis/travis_install.sh
+++ b/scripts/travis/travis_install.sh
@@ -67,4 +67,3 @@ export PATH=/home/travis/miniconda/bin:$PATH
conda update --yes conda
conda install --yes numpy scipy matplotlib scikit-image pip
pip install protobuf
-rm /home/travis/miniconda/lib/libm.*
diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp
index 4a69ca20d0a..104d2b9d669 100644
--- a/src/caffe/layers/cudnn_conv_layer.cpp
+++ b/src/caffe/layers/cudnn_conv_layer.cpp
@@ -24,6 +24,8 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+ workspaceSizeInBytes = 0;
+ workspace = NULL;
for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
@@ -43,10 +45,10 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
- cudnnTensor4dDescriptor_t bottom_desc;
+ cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
- cudnnTensor4dDescriptor_t top_desc;
+ cudnnTensorDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
@@ -104,12 +106,12 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
if (!handles_setup_) { return; }
for (int i = 0; i < bottom_descs_.size(); i++) {
- cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
- cudnnDestroyTensor4dDescriptor(top_descs_[i]);
+ cudnnDestroyTensorDescriptor(bottom_descs_[i]);
+ cudnnDestroyTensorDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
- cudnnDestroyTensor4dDescriptor(bias_desc_);
+ cudnnDestroyTensorDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);
diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu
index 071014e1b48..4a1a4c4f4f2 100644
--- a/src/caffe/layers/cudnn_conv_layer.cu
+++ b/src/caffe/layers/cudnn_conv_layer.cu
@@ -19,23 +19,70 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
Dtype* top_data = top[i]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
+ size_t workspace_limit_bytes = this->kernel_h_ *
+ this->kernel_w_ *
+ this->channels_ *
+ sizeof(int) + 1;
+
// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
+ cudnnConvolutionFwdAlgo_t algo;
+
+ // pick the convolution algorithm
+ // TODO(shelhamer) this should be done during reshape
+ // TODO(shelhamer) the choice of automatic or manual algorithm picking
+ // should be exposed in proto
+ CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes, // memoryLimitInBytes,
+ &algo));
+
+ // get minimum size of the workspace needed for the desired algorithm
+ size_t workspaceSizeInBytes_temp = 0;
+
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ algo,
+ &workspaceSizeInBytes_temp));
+
+ if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
+ workspaceSizeInBytes = workspaceSizeInBytes_temp;
+ // free the existing workspace and allocate a new (larger) one
+ cudaFree(this->workspace);
+ cudaError_t err = cudaMalloc(&(this->workspace), workspaceSizeInBytes);
+ if (err != cudaSuccess) {
+ // force zero memory path
+ algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ workspace = NULL;
+ workspaceSizeInBytes = 0;
+ }
+ }
+
// Filters.
CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
- bottom_descs_[i], bottom_data + bottom_offset_ * g,
- filter_desc_, weight + weight_offset_ * g,
- conv_descs_[i],
- top_descs_[i], top_data + top_offset_ * g,
- CUDNN_RESULT_NO_ACCUMULATE));
+ cudnn::dataType<Dtype>::one,
+ bottom_descs_[i], bottom_data + bottom_offset_ * g,
+ filter_desc_, weight + weight_offset_ * g,
+ conv_descs_[i],
+ algo, workspace, workspaceSizeInBytes,
+ cudnn::dataType<Dtype>::zero,
+ top_descs_[i], top_data + top_offset_ * g));
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
- Dtype alpha = 1.;
- CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
- bias_desc_, bias_data + bias_offset_ * g,
- top_descs_[i], top_data + top_offset_ * g));
+ CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
+ cudnn::dataType<Dtype>::one,
+ bias_desc_, bias_data + bias_offset_ * g,
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i], top_data + top_offset_ * g));
}
}
@@ -68,20 +115,22 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
- top_descs_[i], top_diff + top_offset_ * g,
- bias_desc_, bias_diff + bias_offset_ * g,
- CUDNN_RESULT_ACCUMULATE));
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i], top_diff + top_offset_ * g,
+ cudnn::dataType<Dtype>::one,
+ bias_desc_, bias_diff + bias_offset_ * g));
}
// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
- bottom_descs_[i], bottom_data + bottom_offset_ * g,
- top_descs_[i], top_diff + top_offset_ * g,
- conv_descs_[i],
- filter_desc_, weight_diff + weight_offset_ * g,
- CUDNN_RESULT_ACCUMULATE));
+ cudnn::dataType<Dtype>::one,
+ bottom_descs_[i], bottom_data + bottom_offset_ * g,
+ top_descs_[i], top_diff + top_offset_ * g,
+ conv_descs_[i],
+ cudnn::dataType<Dtype>::one,
+ filter_desc_, weight_diff + weight_offset_ * g));
}
// Gradient w.r.t. bottom data.
@@ -91,11 +140,12 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
- filter_desc_, weight + weight_offset_ * g,
- top_descs_[i], top_diff + top_offset_ * g,
- conv_descs_[i],
- bottom_descs_[i], bottom_diff + bottom_offset_ * g,
- CUDNN_RESULT_NO_ACCUMULATE));
+ cudnn::dataType<Dtype>::one,
+ filter_desc_, weight + weight_offset_ * g,
+ top_descs_[i], top_diff + top_offset_ * g,
+ conv_descs_[i],
+ cudnn::dataType<Dtype>::zero,
+ bottom_descs_[i], bottom_diff + bottom_offset_ * g));
}
}
diff --git a/src/caffe/layers/cudnn_pooling_layer.cpp b/src/caffe/layers/cudnn_pooling_layer.cpp
index dd90195637b..c92c4e477b5 100644
--- a/src/caffe/layers/cudnn_pooling_layer.cpp
+++ b/src/caffe/layers/cudnn_pooling_layer.cpp
@@ -13,15 +13,13 @@ template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingLayer<Dtype>::LayerSetUp(bottom, top);
- // Sanity check: CUDNN currently only supports pad == 0.
- CHECK_EQ(this->pad_h_, 0);
- CHECK_EQ(this->pad_w_, 0);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
- this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
+ this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_,
+ this->stride_h_, this->stride_w_);
handles_setup_ = true;
}
@@ -40,8 +38,8 @@ CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(bottom_desc_);
- cudnnDestroyTensor4dDescriptor(top_desc_);
+ cudnnDestroyTensorDescriptor(bottom_desc_);
+ cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
diff --git a/src/caffe/layers/cudnn_pooling_layer.cu b/src/caffe/layers/cudnn_pooling_layer.cu
index 1c113aad75f..a952b855a48 100644
--- a/src/caffe/layers/cudnn_pooling_layer.cu
+++ b/src/caffe/layers/cudnn_pooling_layer.cu
@@ -15,7 +15,10 @@ void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
- bottom_desc_, bottom_data, top_desc_, top_data));
+ cudnn::dataType<Dtype>::one,
+ bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ top_desc_, top_data));
}
template <typename Dtype>
@@ -29,8 +32,11 @@ void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
- top_desc_, top_data, top_desc_, top_diff,
- bottom_desc_, bottom_data, bottom_desc_, bottom_diff));
+ cudnn::dataType<Dtype>::one,
+ top_desc_, top_data, top_desc_, top_diff,
+ bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer);
diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp
index 0b8a6bc3248..759d83984ef 100644
--- a/src/caffe/layers/cudnn_relu_layer.cpp
+++ b/src/caffe/layers/cudnn_relu_layer.cpp
@@ -35,8 +35,8 @@ CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_relu_layer.cu b/src/caffe/layers/cudnn_relu_layer.cu
index 862508707a0..21d14857dd2 100644
--- a/src/caffe/layers/cudnn_relu_layer.cu
+++ b/src/caffe/layers/cudnn_relu_layer.cu
@@ -18,8 +18,11 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_RELU,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_RELU,
+ cudnn::dataType<Dtype>::one,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -40,9 +43,12 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_RELU,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_RELU,
+ cudnn::dataType<Dtype>::one,
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp
index 67bd9c373b0..32637873d46 100644
--- a/src/caffe/layers/cudnn_sigmoid_layer.cpp
+++ b/src/caffe/layers/cudnn_sigmoid_layer.cpp
@@ -35,8 +35,8 @@ CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cu b/src/caffe/layers/cudnn_sigmoid_layer.cu
index 31b094e25d4..7a06cf721da 100644
--- a/src/caffe/layers/cudnn_sigmoid_layer.cu
+++ b/src/caffe/layers/cudnn_sigmoid_layer.cu
@@ -13,8 +13,11 @@ void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_SIGMOID,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_SIGMOID,
+ cudnn::dataType<Dtype>::one,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -30,9 +33,12 @@ void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_SIGMOID,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_SIGMOID,
+ cudnn::dataType<Dtype>::one,
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer);
diff --git a/src/caffe/layers/cudnn_softmax_layer.cpp b/src/caffe/layers/cudnn_softmax_layer.cpp
index 211701cad49..77a3225adcd 100644
--- a/src/caffe/layers/cudnn_softmax_layer.cpp
+++ b/src/caffe/layers/cudnn_softmax_layer.cpp
@@ -39,8 +39,8 @@ CuDNNSoftmaxLayer<Dtype>::~CuDNNSoftmaxLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(bottom_desc_);
- cudnnDestroyTensor4dDescriptor(top_desc_);
+ cudnnDestroyTensorDescriptor(bottom_desc_);
+ cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroy(handle_);
}
diff --git a/src/caffe/layers/cudnn_softmax_layer.cu b/src/caffe/layers/cudnn_softmax_layer.cu
index f328afdd831..a9e2fcefaf7 100644
--- a/src/caffe/layers/cudnn_softmax_layer.cu
+++ b/src/caffe/layers/cudnn_softmax_layer.cu
@@ -17,8 +17,11 @@ void CuDNNSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE,
- CUDNN_SOFTMAX_MODE_CHANNEL,
- bottom_desc_, bottom_data, top_desc_, top_data));
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ cudnn::dataType<Dtype>::one,
+ bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ top_desc_, top_data));
}
template <typename Dtype>
@@ -29,9 +32,13 @@ void CuDNNSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE,
- CUDNN_SOFTMAX_MODE_CHANNEL,
- top_desc_, top_data, top_desc_, top_diff, bottom_desc_, bottom_diff));
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ cudnn::dataType<Dtype>::one,
+ top_desc_, top_data, top_desc_, top_diff,
+ cudnn::dataType<Dtype>::zero,
+ bottom_desc_, bottom_diff));
}
}
diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp
index b1d2b86384e..376faad324d 100644
--- a/src/caffe/layers/cudnn_tanh_layer.cpp
+++ b/src/caffe/layers/cudnn_tanh_layer.cpp
@@ -35,8 +35,8 @@ CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_tanh_layer.cu b/src/caffe/layers/cudnn_tanh_layer.cu
index bf9ec7cfac4..d287f6fee85 100644
--- a/src/caffe/layers/cudnn_tanh_layer.cu
+++ b/src/caffe/layers/cudnn_tanh_layer.cu
@@ -13,8 +13,11 @@ void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_TANH,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_TANH,
+ cudnn::dataType<Dtype>::one,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -29,10 +32,14 @@ void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_TANH,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_TANH,
+ cudnn::dataType<Dtype>::one,
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ cudnn::dataType<Dtype>::zero,
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 1ceb6c24431..8a782f7e524 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -14,9 +14,9 @@
#include "hdf5_hl.h"
#include "stdint.h"
+#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
-#include "caffe/vision_layers.hpp"
namespace caffe {
@@ -48,11 +48,25 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
CHECK_GE(status, 0) << "Failed to close HDF5 file: " << filename;
// MinTopBlobs==1 guarantees at least one top blob
- int num = hdf_blobs_[0]->num();
+ CHECK_GE(hdf_blobs_[0]->num_axes(), 1) << "Input must have at least 1 axis.";
+ const int num = hdf_blobs_[0]->shape(0);
for (int i = 1; i < top_size; ++i) {
- CHECK_EQ(hdf_blobs_[i]->num(), num);
+ CHECK_EQ(hdf_blobs_[i]->shape(0), num);
+ }
+ // Default to identity permutation.
+ data_permutation_.clear();
+ data_permutation_.resize(hdf_blobs_[0]->shape(0));
+ for (int i = 0; i < hdf_blobs_[0]->shape(0); i++)
+ data_permutation_[i] = i;
+
+ // Shuffle if needed.
+ if (this->layer_param_.hdf5_data_param().shuffle()) {
+ std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
+ DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0)
+ << " rows (shuffled)";
+ } else {
+ DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0) << " rows";
}
- DLOG(INFO) << "Successully loaded " << hdf_blobs_[0]->num() << " rows";
}
template <typename Dtype>
@@ -81,8 +95,20 @@ void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GE(num_files_, 1) << "Must have at least 1 HDF5 filename listed in "
<< source;
+ file_permutation_.clear();
+ file_permutation_.resize(num_files_);
+ // Default to identity permutation.
+ for (int i = 0; i < num_files_; i++) {
+ file_permutation_[i] = i;
+ }
+
+ // Shuffle if needed.
+ if (this->layer_param_.hdf5_data_param().shuffle()) {
+ std::random_shuffle(file_permutation_.begin(), file_permutation_.end());
+ }
+
// Load the first HDF5 file and initialize the line counter.
- LoadHDF5FileData(hdf_filenames_[current_file_].c_str());
+ LoadHDF5FileData(hdf_filenames_[file_permutation_[current_file_]].c_str());
current_row_ = 0;
// Reshape blobs.
@@ -104,22 +130,29 @@ void HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int batch_size = this->layer_param_.hdf5_data_param().batch_size();
for (int i = 0; i < batch_size; ++i, ++current_row_) {
- if (current_row_ == hdf_blobs_[0]->num()) {
+ if (current_row_ == hdf_blobs_[0]->shape(0)) {
if (num_files_ > 1) {
++current_file_;
if (current_file_ == num_files_) {
current_file_ = 0;
+ if (this->layer_param_.hdf5_data_param().shuffle()) {
+ std::random_shuffle(file_permutation_.begin(),
+ file_permutation_.end());
+ }
DLOG(INFO) << "Looping around to first file.";
}
- LoadHDF5FileData(hdf_filenames_[current_file_].c_str());
+ LoadHDF5FileData(
+ hdf_filenames_[file_permutation_[current_file_]].c_str());
}
current_row_ = 0;
+ if (this->layer_param_.hdf5_data_param().shuffle())
+ std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
}
for (int j = 0; j < this->layer_param_.top_size(); ++j) {
- int data_dim = top[j]->count() / top[j]->num();
+ int data_dim = top[j]->count() / top[j]->shape(0);
caffe_copy(data_dim,
- &hdf_blobs_[j]->cpu_data()[current_row_ * data_dim],
- &top[j]->mutable_cpu_data()[i * data_dim]);
+ &hdf_blobs_[j]->cpu_data()[data_permutation_[current_row_]
+ * data_dim], &top[j]->mutable_cpu_data()[i * data_dim]);
}
}
}
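The shuffle added here permutes at two levels, files and rows within the currently loaded file, without interleaving rows across files; a standalone sketch of those semantics (toy counts, no HDF5 involved):

```cpp
// Two-level shuffling: file order is permuted once per pass, and row
// order is re-permuted each time a file is (re)loaded. All rows of one
// file are emitted before the next file is opened.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> file_permutation(3);
  for (int i = 0; i < 3; ++i) file_permutation[i] = i;
  std::random_shuffle(file_permutation.begin(), file_permutation.end());
  for (int f = 0; f < 3; ++f) {
    std::vector<int> data_permutation(4);  // rows of the "loaded" file
    for (int i = 0; i < 4; ++i) data_permutation[i] = i;
    std::random_shuffle(data_permutation.begin(), data_permutation.end());
    for (int r = 0; r < 4; ++r)
      std::printf("file %d row %d\n", file_permutation[f],
                  data_permutation[r]);
  }
  return 0;
}
```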
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 02e3821d104..5e3e4ced141 100644
--- a/src/caffe/layers/hdf5_data_layer.cu
+++ b/src/caffe/layers/hdf5_data_layer.cu
@@ -10,9 +10,9 @@ TODO:
#include "hdf5.h"
#include "hdf5_hl.h"
+#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
-#include "caffe/vision_layers.hpp"
namespace caffe {
@@ -21,22 +21,29 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int batch_size = this->layer_param_.hdf5_data_param().batch_size();
for (int i = 0; i < batch_size; ++i, ++current_row_) {
- if (current_row_ == hdf_blobs_[0]->num()) {
+ if (current_row_ == hdf_blobs_[0]->shape(0)) {
if (num_files_ > 1) {
current_file_ += 1;
if (current_file_ == num_files_) {
current_file_ = 0;
+ if (this->layer_param_.hdf5_data_param().shuffle()) {
+ std::random_shuffle(file_permutation_.begin(),
+ file_permutation_.end());
+ }
DLOG(INFO) << "Looping around to first file.";
}
- LoadHDF5FileData(hdf_filenames_[current_file_].c_str());
+ LoadHDF5FileData(
+ hdf_filenames_[file_permutation_[current_file_]].c_str());
}
current_row_ = 0;
+ if (this->layer_param_.hdf5_data_param().shuffle())
+ std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
}
for (int j = 0; j < this->layer_param_.top_size(); ++j) {
- int data_dim = top[j]->count() / top[j]->num();
+ int data_dim = top[j]->count() / top[j]->shape(0);
caffe_copy(data_dim,
- &hdf_blobs_[j]->cpu_data()[current_row_ * data_dim],
- &top[j]->mutable_gpu_data()[i * data_dim]);
+ &hdf_blobs_[j]->cpu_data()[data_permutation_[current_row_]
+ * data_dim], &top[j]->mutable_gpu_data()[i * data_dim]);
}
}
}
diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu
index 58c39926c72..24aa6a30130 100644
--- a/src/caffe/layers/lrn_layer.cu
+++ b/src/caffe/layers/lrn_layer.cu
@@ -26,26 +26,24 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
Dtype accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
- while (head < post_pad) {
+ while (head < post_pad && head < channels) {
accum_scale += in[head * step] * in[head * step];
++head;
}
- // until we reach size, nothing needs to be subtracted
- while (head < size) {
- accum_scale += in[head * step] * in[head * step];
- scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
- ++head;
- }
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
- accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+ if (head - size >= 0) {
+ accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+ }
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
// subtract only
while (head < channels + post_pad) {
- accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+ if (head - size >= 0) {
+ accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+ }
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
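The `head < channels` and `head - size >= 0` guards above keep the running sum inside valid channels even when the LRN window is larger than the channel count; for reference, the scale term being maintained is, in scalar form (an editorial sketch of the across-channels semantics, not the kernel itself):

```cpp
// scale[c] = k + (alpha / size) * sum of in[j]^2 over the window
// [c - pre_pad, c - pre_pad + size), clamped to [0, channels).
#include <algorithm>
#include <vector>

std::vector<float> LrnScale(const std::vector<float>& in, int size,
                            float alpha, float k) {
  const int channels = static_cast<int>(in.size());
  const int pre_pad = (size - 1) / 2;
  std::vector<float> scale(channels);
  for (int c = 0; c < channels; ++c) {
    float accum = 0.f;
    const int start = std::max(0, c - pre_pad);
    const int end = std::min(channels, c - pre_pad + size);
    for (int j = start; j < end; ++j)
      accum += in[j] * in[j];
    scale[c] = k + accum * alpha / size;  // well-defined even if size > channels
  }
  return scale;
}
```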
@@ -143,26 +141,19 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
int post_pad = size - pre_pad - 1;
Dtype accum_ratio = 0;
// accumulate values
- while (head < post_pad) {
+ while (head < post_pad && head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
++head;
}
- // until we reach size, nothing needs to be subtracted
- while (head < size) {
- accum_ratio += top_diff[head * step] * top_data[head * step] /
- scale[head * step];
- bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
- * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
- bottom_data[(head - post_pad) * step] * accum_ratio;
- ++head;
- }
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
- accum_ratio -= top_diff[(head - size) * step] *
- top_data[(head - size) * step] / scale[(head - size) * step];
+ if (head - size >= 0) {
+ accum_ratio -= top_diff[(head - size) * step] *
+ top_data[(head - size) * step] / scale[(head - size) * step];
+ }
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
@@ -170,8 +161,10 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
}
// subtract only
while (head < channels + post_pad) {
- accum_ratio -= top_diff[(head - size) * step] *
- top_data[(head - size) * step] / scale[(head - size) * step];
+ if (head - size >= 0) {
+ accum_ratio -= top_diff[(head - size) * step] *
+ top_data[(head - size) * step] / scale[(head - size) * step];
+ }
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
diff --git a/src/caffe/layers/prelu_layer.cpp b/src/caffe/layers/prelu_layer.cpp
new file mode 100644
index 00000000000..7119a274dd3
--- /dev/null
+++ b/src/caffe/layers/prelu_layer.cpp
@@ -0,0 +1,140 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void PReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ CHECK_GE(bottom[0]->num_axes(), 2)
+ << "Number of axes of bottom blob must be >=2.";
+ PReLUParameter prelu_param = this->layer_param().prelu_param();
+ int channels = bottom[0]->channels();
+ channel_shared_ = prelu_param.channel_shared();
+ if (this->blobs_.size() > 0) {
+ LOG(INFO) << "Skipping parameter initialization";
+ } else {
+ this->blobs_.resize(1);
+ if (channel_shared_) {
+ this->blobs_[0].reset(new Blob<Dtype>(vector<int>(0)));
+ } else {
+ this->blobs_[0].reset(new Blob<Dtype>(vector<int>(1, channels)));
+ }
+ shared_ptr<Filler<Dtype> > filler;
+ if (prelu_param.has_filler()) {
+ filler.reset(GetFiller<Dtype>(prelu_param.filler()));
+ } else {
+ FillerParameter filler_param;
+ filler_param.set_type("constant");
+ filler_param.set_value(0.25);
+ filler.reset(GetFiller<Dtype>(filler_param));
+ }
+ filler->Fill(this->blobs_[0].get());
+ }
+ if (channel_shared_) {
+ CHECK_EQ(this->blobs_[0]->count(), 1)
+ << "Negative slope size is inconsistent with prototxt config";
+ } else {
+ CHECK_EQ(this->blobs_[0]->count(), channels)
+ << "Negative slope size is inconsistent with prototxt config";
+ }
+
+ // Propagate gradients to the parameters (as directed by backward pass).
+ this->param_propagate_down_.resize(this->blobs_.size(), true);
+ multiplier_.Reshape(vector<int>(1, bottom[0]->count() / bottom[0]->num()));
+ caffe_set(multiplier_.count(), Dtype(1), multiplier_.mutable_cpu_data());
+}
+
+template <typename Dtype>
+void PReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ CHECK_GE(bottom[0]->num_axes(), 2)
+ << "Number of axes of bottom blob must be >=2.";
+ top[0]->ReshapeLike(*bottom[0]);
+ if (bottom[0] == top[0]) {
+ // For in-place computation
+ bottom_memory_.ReshapeLike(*bottom[0]);
+ }
+}
+
+template <typename Dtype>
+void PReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+ const int dim = bottom[0]->count(2);
+ const int channels = bottom[0]->channels();
+ const Dtype* slope_data = this->blobs_[0]->cpu_data();
+
+ // For in-place computation
+ if (bottom[0] == top[0]) {
+ caffe_copy(count, bottom_data, bottom_memory_.mutable_cpu_data());
+ }
+
+ // if channel_shared, channel index in the following computation becomes
+ // always zero.
+ const int div_factor = channel_shared_ ? channels : 1;
+ for (int i = 0; i < count; ++i) {
+ int c = (i / dim) % channels / div_factor;
+ top_data[i] = std::max(bottom_data[i], Dtype(0))
+ + slope_data[c] * std::min(bottom_data[i], Dtype(0));
+ }
+}
+
+template <typename Dtype>
+void PReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* slope_data = this->blobs_[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const int count = bottom[0]->count();
+ const int dim = bottom[0]->count(2);
+ const int channels = bottom[0]->channels();
+
+ // For in-place computation
+ if (top[0] == bottom[0]) {
+ bottom_data = bottom_memory_.cpu_data();
+ }
+
+ // if channel_shared, channel index in the following computation becomes
+ // always zero.
+ const int div_factor = channel_shared_ ? channels : 1;
+
+ // Propagate to param
+ // Since writing bottom diff will affect top diff if top and bottom blobs
+ // are identical (in-place computation), we first compute param backward to
+ // keep top_diff unchanged.
+ if (this->param_propagate_down_[0]) {
+ Dtype* slope_diff = this->blobs_[0]->mutable_cpu_diff();
+ caffe_set(this->blobs_[0]->count(), Dtype(0), slope_diff);
+ for (int i = 0; i < count; ++i) {
+ int c = (i / dim) % channels / div_factor;
+ slope_diff[c] += top_diff[i] * bottom_data[i] * (bottom_data[i] <= 0);
+ }
+ }
+ // Propagate to bottom
+ if (propagate_down[0]) {
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ for (int i = 0; i < count; ++i) {
+ int c = (i / dim) % channels / div_factor;
+ bottom_diff[i] = top_diff[i] * ((bottom_data[i] > 0)
+ + slope_data[c] * (bottom_data[i] <= 0));
+ }
+ }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(PReLULayer);
+#endif
+
+INSTANTIATE_CLASS(PReLULayer);
+REGISTER_LAYER_CLASS(PReLU);
+
+} // namespace caffe
diff --git a/src/caffe/layers/prelu_layer.cu b/src/caffe/layers/prelu_layer.cu
new file mode 100644
index 00000000000..fd0eda5d191
--- /dev/null
+++ b/src/caffe/layers/prelu_layer.cu
@@ -0,0 +1,130 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+// CUDA kernel for forward
+template <typename Dtype>
+__global__ void PReLUForward(const int n, const int channels, const int dim,
+ const Dtype* in, Dtype* out, const Dtype* slope_data,
+ const int div_factor) {
+ CUDA_KERNEL_LOOP(index, n) {
+ int c = (index / dim) % channels / div_factor;
+ out[index] = in[index] > 0 ? in[index] : in[index] * slope_data[c];
+ }
+}
+
+// CUDA kernel for bottom backward
+template <typename Dtype>
+__global__ void PReLUBackward(const int n, const int channels, const int dim,
+ const Dtype* in_diff, const Dtype* in_data, Dtype* out_diff,
+ const Dtype* slope_data, const int div_factor) {
+ CUDA_KERNEL_LOOP(index, n) {
+ int c = (index / dim) % channels / div_factor;
+ out_diff[index] = in_diff[index] * ((in_data[index] > 0)
+ + (in_data[index] <= 0) * slope_data[c]);
+ }
+}
+
+// CUDA kernel for element-wise parameter backward
+template <typename Dtype>
+__global__ void PReLUParamBackward(const int n, const Dtype* in_diff,
+ const Dtype* in_data, Dtype* out_diff) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out_diff[index] = in_diff[index] * in_data[index] * (in_data[index] <= 0);
+ }
+}
+
+template <typename Dtype>
+void PReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ const int dim = bottom[0]->count(2);
+ const int channels = bottom[0]->channels();
+ const Dtype* slope_data = this->blobs_[0]->gpu_data();
+ const int div_factor = channel_shared_ ? channels : 1;
+
+ // For in-place computation
+ if (top[0] == bottom[0]) {
+ caffe_copy(count, bottom_data, bottom_memory_.mutable_gpu_data());
+ }
+
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ PReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, channels, dim, bottom_data, top_data, slope_data, div_factor);
+ CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+void PReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const int count = bottom[0]->count();
+ const int dim = bottom[0]->count(2);
+ const int channels = bottom[0]->channels();
+
+ // For in-place computation
+ if (top[0] == bottom[0]) {
+ bottom_data = bottom_memory_.gpu_data();
+ }
+
+ // Propagate to param
+ // Since writing bottom diff will affect top diff if top and bottom blobs
+ // are identical (in-place computation), we first compute param backward to
+ // keep top_diff unchanged.
+ if (this->param_propagate_down_[0]) {
+ Dtype* slope_diff = this->blobs_[0]->mutable_gpu_diff();
+ // slope_diff is set as 0, then accumulated over batches
+ caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), slope_diff);
+ int cdim = channels * dim;
+ Dtype dsum = 0.;
+ for (int n = 0; n < bottom[0]->num(); ++n) {
+ Dtype* temp_buff = multiplier_.mutable_gpu_diff();
+ // compute element-wise diff
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ PReLUParamBackward<Dtype><<<CAFFE_GET_BLOCKS(cdim), CAFFE_CUDA_NUM_THREADS>>>(
+ cdim, top_diff + top[0]->offset(n),
+ bottom_data + bottom[0]->offset(n), multiplier_.mutable_gpu_diff());
+ CUDA_POST_KERNEL_CHECK;
+ if (channel_shared_) {
+ Dtype d;
+ caffe_gpu_dot(channels * dim, multiplier_.gpu_diff(),
+ multiplier_.gpu_data(), &d);
+ dsum += d;
+ } else {
+ caffe_gpu_gemv<Dtype>(CblasNoTrans, channels, dim, 1.,
+ multiplier_.gpu_diff(), multiplier_.gpu_data(), 1.,
+ slope_diff);
+ }
+ }
+ if (channel_shared_) {
+ caffe_gpu_set(this->blobs_[0]->count(), Dtype(dsum), slope_diff);
+ }
+ }
+ // Propagate to bottom
+ if (propagate_down[0]) {
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const Dtype* slope_data = this->blobs_[0]->gpu_data();
+ int div_factor = channel_shared_ ? channels : 1;
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ PReLUBackward<<>>(
+ count, channels, dim, top_diff, bottom_data, bottom_diff, slope_data,
+ div_factor);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(PReLULayer);
+
+
+} // namespace caffe
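
For reference, the kernels above implement the PReLU function f(x) = max(0, x) + a * min(0, x) from the He et al. paper and its two gradients. A minimal CPU sketch of the same math (illustrative helpers only, not part of the diff):

    #include <algorithm>

    // Forward: one learned slope a per channel, or a single shared slope
    // when channel_shared is set (this is what PReLUForward computes).
    template <typename Dtype>
    Dtype prelu_forward(Dtype x, Dtype slope) {
      return std::max(x, Dtype(0)) + slope * std::min(x, Dtype(0));
    }

    // Bottom gradient: df/dx = 1 for x > 0, a for x <= 0
    // (what PReLUBackward computes).
    template <typename Dtype>
    Dtype prelu_bottom_grad(Dtype x, Dtype slope, Dtype top_diff) {
      return top_diff * (x > 0 ? Dtype(1) : slope);
    }

    // Slope gradient: df/da = x for x <= 0, else 0. PReLUParamBackward
    // produces these element-wise terms; the gemv/dot calls then reduce
    // them per channel (or over everything when the slope is shared).
    template <typename Dtype>
    Dtype prelu_slope_grad(Dtype x, Dtype top_diff) {
      return top_diff * x * (x <= 0);
    }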
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index e523efa50f1..5b21cf20028 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -259,7 +259,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 131 (last added: python_param)
+// LayerParameter next available layer-specific ID: 132 (last added: prelu_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -323,6 +323,7 @@ message LayerParameter {
optional MVNParameter mvn_param = 120;
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
+ optional PReLUParameter prelu_param = 131;
optional PythonParameter python_param = 130;
optional ReLUParameter relu_param = 123;
optional SigmoidParameter sigmoid_param = 124;
@@ -517,6 +518,13 @@ message HDF5DataParameter {
optional string source = 1;
// Specify the batch size.
optional uint32 batch_size = 2;
+
+ // Specify whether to shuffle the data.
+ // If shuffle == true, the ordering of the HDF5 files is shuffled,
+ // and the ordering of data within any given HDF5 file is shuffled,
+ // but data between different files are not interleaved; all of a file's
+ // data are output (in a random order) before moving on to another file.
+ optional bool shuffle = 3 [default = false];
}
// Message that stores parameters used by HDF5OutputLayer
@@ -946,3 +954,14 @@ message V0LayerParameter {
optional HDF5OutputParameter hdf5_output_param = 1001;
}
+
+// Message that stores parameters used by PReLULayer
+message PReLUParameter {
+ // Parametric ReLU described in K. He et al., Delving Deep into Rectifiers:
+ // Surpassing Human-Level Performance on ImageNet Classification, 2015.
+
+ // Initial value of a_i. Default is a_i=0.25 for all i.
+ optional FillerParameter filler = 1;
+ // Whether or not slope parameters are shared across channels.
+ optional bool channel_shared = 2 [default = false];
+}
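
To show the two new messages in use, here is a hypothetical prototxt fragment (layer and blob names are invented for illustration):

    layer {
      name: "data"
      type: "HDF5Data"
      top: "data"
      top: "label"
      hdf5_data_param {
        source: "train_files.txt"
        batch_size: 64
        shuffle: true  # shuffles file order and within-file order, per the comment above
      }
    }
    layer {
      name: "relu1"
      type: "PReLU"
      bottom: "fc1"
      top: "fc1"  # in-place computation is supported by prelu_layer.cu
      prelu_param {
        filler { type: "constant" value: 0.25 }  # matches the default a_i = 0.25
        channel_shared: false
      }
    }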
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 034390e6824..096980dd7af 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -349,7 +349,7 @@ void Solver::Restore(const char* state_file) {
NetParameter net_param;
ReadProtoFromBinaryFile(state_file, &state);
if (state.has_learned_net()) {
- ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+ ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
net_->CopyTrainedLayersFrom(net_param);
}
iter_ = state.iter();
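
The one-line Restore() change swaps the raw proto reader for the checked variant. A sketch of the difference, assuming the helper declared in caffe/util/upgrade_proto.hpp:

    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/upgrade_proto.hpp"

    void LoadLearnedNet(const char* path, caffe::NetParameter* net_param) {
      // Unlike ReadProtoFromBinaryFile, this dies with a clear error on a
      // malformed file and upgrades nets saved in deprecated formats
      // before they reach CopyTrainedLayersFrom.
      caffe::ReadNetParamsFromBinaryFileOrDie(path, net_param);
    }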
diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp
index 07425df9b3a..c4e2f8ea7f2 100644
--- a/src/caffe/test/test_lrn_layer.cpp
+++ b/src/caffe/test/test_lrn_layer.cpp
@@ -138,6 +138,22 @@ TYPED_TEST(LRNLayerTest, TestForwardAcrossChannels) {
}
}
+TYPED_TEST(LRNLayerTest, TestForwardAcrossChannelsLargeRegion) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_local_size(15);
+  LRNLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
TYPED_TEST(LRNLayerTest, TestGradientAcrossChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
@@ -159,6 +175,28 @@ TYPED_TEST(LRNLayerTest, TestGradientAcrossChannels) {
this->blob_top_vec_);
}
+TYPED_TEST(LRNLayerTest, TestGradientAcrossChannelsLargeRegion) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_local_size(15);
+  LRNLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 this->blob_bottom_vec_);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
TYPED_TEST(LRNLayerTest, TestSetupWithinChannel) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
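
The new LargeRegion tests set local_size to 15 so the normalization window clips at both channel boundaries instead of always fitting inside the channel range. A rough single-pixel sketch of across-channel LRN using Caffe's defaults (k = 1, alpha scaled by the window size); the helper below is illustrative, not the test's ReferenceLRNForward:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    float lrn_across_channels(const std::vector<float>& x, int c,
                              int local_size, float alpha, float beta) {
      const int channels = static_cast<int>(x.size());
      // Window centered on channel c, clipped to the valid channel range;
      // with local_size = 15 the clipping happens at both ends.
      const int lo = std::max(c - (local_size - 1) / 2, 0);
      const int hi = std::min(c + (local_size - 1) / 2, channels - 1);
      float sum_sq = 0.f;
      for (int j = lo; j <= hi; ++j) sum_sq += x[j] * x[j];
      const float scale = 1.f + alpha / local_size * sum_sq;
      return x[c] / std::pow(scale, beta);
    }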
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index ad10720116d..c9d52f247a6 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <cmath>
 #include <vector>
@@ -99,6 +100,23 @@ class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
     GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientEltwise(&layer, blob_bottom_vec_, blob_top_vec_);
}
+
+  void TestPReLU(PReLULayer<Dtype> *layer) {
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    // Now, check values
+    const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+    const Dtype* top_data = this->blob_top_->cpu_data();
+    const Dtype* slope_data = layer->blobs()[0]->cpu_data();
+    int hw = this->blob_bottom_->height() * this->blob_bottom_->width();
+    int channels = this->blob_bottom_->channels();
+    bool channel_shared = layer->layer_param().prelu_param().channel_shared();
+    for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+      int c = channel_shared ? 0 : (i / hw) % channels;
+      EXPECT_EQ(top_data[i],
+          std::max(bottom_data[i], (Dtype)(0))
+          + slope_data[c] * std::min(bottom_data[i], (Dtype)(0)));
+    }
+  }
};
TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices);
@@ -392,6 +410,184 @@ TYPED_TEST(NeuronLayerTest, TestBNLLGradient) {
this->blob_top_vec_);
}
+TYPED_TEST(NeuronLayerTest, TestPReLUParam) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  const Dtype* slopes = layer.blobs()[0]->cpu_data();
+  int count = layer.blobs()[0]->count();
+  for (int i = 0; i < count; ++i, ++slopes) {
+    EXPECT_EQ(*slopes, 0.25);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(layer.blobs()[0].get());
+  this->TestPReLU(&layer);
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUForwardChannelShared) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_prelu_param()->set_channel_shared(true);
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  this->TestPReLU(&layer);
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(layer.blobs()[0].get());
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUGradientChannelShared) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_prelu_param()->set_channel_shared(true);
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUConsistencyReLU) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter prelu_layer_param;
+  LayerParameter relu_layer_param;
+  relu_layer_param.mutable_relu_param()->set_negative_slope(0.25);
+  PReLULayer<Dtype> prelu(prelu_layer_param);
+  ReLULayer<Dtype> relu(relu_layer_param);
+  // Set up blobs
+  vector<Blob<Dtype>*> blob_bottom_vec_2;
+  vector<Blob<Dtype>*> blob_top_vec_2;
+  shared_ptr<Blob<Dtype> > blob_bottom_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_top_2(new Blob<Dtype>());
+  blob_bottom_vec_2.push_back(blob_bottom_2.get());
+  blob_top_vec_2.push_back(blob_top_2.get());
+  blob_bottom_2->CopyFrom(*this->blob_bottom_, false, true);
+  // SetUp layers
+  prelu.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  relu.SetUp(blob_bottom_vec_2, blob_top_vec_2);
+  // Check forward
+  prelu.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  relu.Forward(blob_bottom_vec_2, blob_top_vec_2);
+  for (int s = 0; s < blob_top_2->count(); ++s) {
+    EXPECT_EQ(this->blob_top_->cpu_data()[s], blob_top_2->cpu_data()[s]);
+  }
+  // Check backward
+  shared_ptr<Blob<Dtype> > tmp_blob(new Blob<Dtype>());
+  tmp_blob->ReshapeLike(*blob_top_2.get());
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(tmp_blob.get());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      this->blob_top_->mutable_cpu_diff());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      blob_top_2->mutable_cpu_diff());
+  vector<bool> propagate_down;
+  propagate_down.push_back(true);
+  prelu.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
+  relu.Backward(blob_top_vec_2, propagate_down, blob_bottom_vec_2);
+  for (int s = 0; s < blob_bottom_2->count(); ++s) {
+    EXPECT_EQ(this->blob_bottom_->cpu_diff()[s], blob_bottom_2->cpu_diff()[s]);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUInPlace) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Set layer parameters
+  LayerParameter ip_layer_param;
+  LayerParameter prelu_layer_param;
+  InnerProductParameter *ip_param =
+      ip_layer_param.mutable_inner_product_param();
+  ip_param->mutable_weight_filler()->set_type("gaussian");
+  ip_param->set_num_output(3);
+  InnerProductLayer<Dtype> ip(ip_layer_param);
+  PReLULayer<Dtype> prelu(prelu_layer_param);
+  InnerProductLayer<Dtype> ip2(ip_layer_param);
+  PReLULayer<Dtype> prelu2(prelu_layer_param);
+  // Set up blobs
+  vector<Blob<Dtype>*> blob_bottom_vec_2;
+  vector<Blob<Dtype>*> blob_middle_vec_2;
+  vector<Blob<Dtype>*> blob_top_vec_2;
+  shared_ptr<Blob<Dtype> > blob_bottom_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_middle_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_top_2(new Blob<Dtype>());
+  blob_bottom_vec_2.push_back(blob_bottom_2.get());
+  blob_middle_vec_2.push_back(blob_middle_2.get());
+  blob_top_vec_2.push_back(blob_top_2.get());
+  blob_bottom_2->CopyFrom(*this->blob_bottom_, false, true);
+  // SetUp layers
+  ip.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  prelu.SetUp(this->blob_top_vec_, this->blob_top_vec_);
+  ip2.SetUp(blob_bottom_vec_2, blob_middle_vec_2);
+  prelu2.SetUp(blob_middle_vec_2, blob_top_vec_2);
+  caffe_copy(ip2.blobs()[0]->count(), ip.blobs()[0]->cpu_data(),
+      ip2.blobs()[0]->mutable_cpu_data());
+  // Forward in-place
+  ip.Reshape(this->blob_bottom_vec_, this->blob_top_vec_);
+  ip.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  prelu.Reshape(this->blob_top_vec_, this->blob_top_vec_);
+  prelu.Forward(this->blob_top_vec_, this->blob_top_vec_);
+  // Forward non-in-place
+  ip2.Reshape(blob_bottom_vec_2, blob_middle_vec_2);
+  ip2.Forward(blob_bottom_vec_2, blob_middle_vec_2);
+  prelu2.Reshape(blob_middle_vec_2, blob_top_vec_2);
+  prelu2.Forward(blob_middle_vec_2, blob_top_vec_2);
+  // Check numbers
+  for (int s = 0; s < blob_top_2->count(); ++s) {
+    EXPECT_EQ(this->blob_top_->cpu_data()[s], blob_top_2->cpu_data()[s]);
+  }
+  // Fill top diff with random numbers
+  shared_ptr<Blob<Dtype> > tmp_blob(new Blob<Dtype>());
+  tmp_blob->ReshapeLike(*blob_top_2.get());
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(tmp_blob.get());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      this->blob_top_->mutable_cpu_diff());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      blob_top_2->mutable_cpu_diff());
+  // Backward in-place
+  vector<bool> propagate_down;
+  propagate_down.push_back(true);
+  prelu.Backward(this->blob_top_vec_, propagate_down, this->blob_top_vec_);
+  ip.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
+  // Backward non-in-place
+  prelu2.Backward(blob_top_vec_2, propagate_down, blob_middle_vec_2);
+  ip2.Backward(blob_middle_vec_2, propagate_down, blob_bottom_vec_2);
+  // Check numbers
+  for (int s = 0; s < blob_bottom_2->count(); ++s) {
+    EXPECT_EQ(this->blob_bottom_->cpu_diff()[s], blob_bottom_2->cpu_diff()[s]);
+  }
+  for (int s = 0; s < ip.blobs()[0]->count(); ++s) {
+    EXPECT_EQ(ip.blobs()[0]->cpu_diff()[s], ip2.blobs()[0]->cpu_diff()[s]);
+  }
+  for (int s = 0; s < ip.blobs()[1]->count(); ++s) {
+    EXPECT_EQ(ip.blobs()[1]->cpu_diff()[s], ip2.blobs()[1]->cpu_diff()[s]);
+  }
+  for (int s = 0; s < prelu.blobs()[0]->count(); ++s) {
+    EXPECT_EQ(prelu.blobs()[0]->cpu_diff()[s],
+        prelu2.blobs()[0]->cpu_diff()[s]);
+  }
+}
+
#ifdef USE_CUDNN
 template <typename Dtype>
class CuDNNNeuronLayerTest : public ::testing::Test {
diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp
index 435caa8381e..e9964e7f0b7 100644
--- a/src/caffe/test/test_pooling_layer.cpp
+++ b/src/caffe/test/test_pooling_layer.cpp
@@ -976,9 +976,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestSetupCuDNN) {
EXPECT_EQ(this->blob_top_->width(), 2);
}
-// This test and all following cuDNN pooling tests with padding are commented
-// for now, since cuDNN pooling does not currently support padding.
-/*
TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
@@ -994,7 +991,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) {
EXPECT_EQ(this->blob_top_->height(), 4);
EXPECT_EQ(this->blob_top_->width(), 3);
}
-*/
/*
TYPED_TEST(CuDNNPoolingLayerTest, PrintBackwardCuDNN) {
@@ -1062,7 +1058,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxCuDNN) {
}
}
-/*
TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
LayerParameter layer_param;
@@ -1107,7 +1102,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) {
EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon);
EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon);
}
-*/
/*
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxTopMaskCuDNN) {
@@ -1175,7 +1169,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAveCuDNN) {
}
}
-/*
TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) {
Caffe::set_mode(Caffe::GPU);
for (int kernel_h = 3; kernel_h <= 4; kernel_h++) {
@@ -1194,7 +1187,6 @@ TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) {
}
}
}
-*/
#endif
diff --git a/src/caffe/util/cudnn.cpp b/src/caffe/util/cudnn.cpp
new file mode 100644
index 00000000000..1772f0099ce
--- /dev/null
+++ b/src/caffe/util/cudnn.cpp
@@ -0,0 +1,23 @@
+#ifdef USE_CUDNN
+#include "caffe/util/cudnn.hpp"
+
+namespace caffe {
+namespace cudnn {
+
+float dataType<float>::oneval = 1.0;
+float dataType<float>::zeroval = 0.0;
+const void* dataType<float>::one =
+    static_cast<void *>(&dataType<float>::oneval);
+const void* dataType<float>::zero =
+    static_cast<void *>(&dataType<float>::zeroval);
+
+double dataType<double>::oneval = 1.0;
+double dataType<double>