From 5f2d845fafc8883aa16b437b79fa52b39f8a0ddb Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sun, 15 Feb 2015 14:28:01 -0800 Subject: [PATCH 1/3] Add RecurrentLayer: an abstract superclass for other recurrent layer types --- include/caffe/layers/recurrent_layer.hpp | 187 ++++++++++++++ src/caffe/layers/recurrent_layer.cpp | 295 +++++++++++++++++++++++ src/caffe/layers/recurrent_layer.cu | 44 ++++ src/caffe/proto/caffe.proto | 22 +- 4 files changed, 547 insertions(+), 1 deletion(-) create mode 100644 include/caffe/layers/recurrent_layer.hpp create mode 100644 src/caffe/layers/recurrent_layer.cpp create mode 100644 src/caffe/layers/recurrent_layer.cu diff --git a/include/caffe/layers/recurrent_layer.hpp b/include/caffe/layers/recurrent_layer.hpp new file mode 100644 index 00000000000..ca17371b994 --- /dev/null +++ b/include/caffe/layers/recurrent_layer.hpp @@ -0,0 +1,187 @@ +#ifndef CAFFE_RECURRENT_LAYER_HPP_ +#define CAFFE_RECURRENT_LAYER_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/format.hpp" + +namespace caffe { + +template class RecurrentLayer; + +/** + * @brief An abstract class for implementing recurrent behavior inside of an + * unrolled network. This Layer type cannot be instantiated -- instead, + * you should use one of its implementations which defines the recurrent + * architecture, such as RNNLayer or LSTMLayer. + */ +template +class RecurrentLayer : public Layer { + public: + explicit RecurrentLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual void Reset(); + + virtual inline const char* type() const { return "Recurrent"; } + virtual inline int MinBottomBlobs() const { + int min_bottoms = 2; + if (this->layer_param_.recurrent_param().expose_hidden()) { + vector inputs; + this->RecurrentInputBlobNames(&inputs); + min_bottoms += inputs.size(); + } + return min_bottoms; + } + virtual inline int MaxBottomBlobs() const { return MinBottomBlobs() + 1; } + virtual inline int ExactNumTopBlobs() const { + int num_tops = 1; + if (this->layer_param_.recurrent_param().expose_hidden()) { + vector outputs; + this->RecurrentOutputBlobNames(&outputs); + num_tops += outputs.size(); + } + return num_tops; + } + + virtual inline bool AllowForceBackward(const int bottom_index) const { + // Can't propagate to sequence continuation indicators. + return bottom_index != 1; + } + + protected: + /** + * @brief Fills net_param with the recurrent network architecture. Subclasses + * should define this -- see RNNLayer and LSTMLayer for examples. + */ + virtual void FillUnrolledNet(NetParameter* net_param) const = 0; + + /** + * @brief Fills names with the names of the 0th timestep recurrent input + * Blob&s. Subclasses should define this -- see RNNLayer and LSTMLayer + * for examples. + */ + virtual void RecurrentInputBlobNames(vector* names) const = 0; + + /** + * @brief Fills shapes with the shapes of the recurrent input Blob&s. + * Subclasses should define this -- see RNNLayer and LSTMLayer + * for examples. + */ + virtual void RecurrentInputShapes(vector* shapes) const = 0; + + /** + * @brief Fills names with the names of the Tth timestep recurrent output + * Blob&s. Subclasses should define this -- see RNNLayer and LSTMLayer + * for examples. 
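+ *        (For example, the RNNLayer implementation below names this blob
+ *        "h_" + T, and LSTMLayer names "h_" + T and "c_T"; see their
+ *        RecurrentOutputBlobNames overrides.)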
+ */ + virtual void RecurrentOutputBlobNames(vector* names) const = 0; + + /** + * @brief Fills names with the names of the output blobs, concatenated across + * all timesteps. Should return a name for each top Blob. + * Subclasses should define this -- see RNNLayer and LSTMLayer for + * examples. + */ + virtual void OutputBlobNames(vector* names) const = 0; + + /** + * @param bottom input Blob vector (length 2-3) + * + * -# @f$ (T \times N \times ...) @f$ + * the time-varying input @f$ x @f$. After the first two axes, whose + * dimensions must correspond to the number of timesteps @f$ T @f$ and + * the number of independent streams @f$ N @f$, respectively, its + * dimensions may be arbitrary. Note that the ordering of dimensions -- + * @f$ (T \times N \times ...) @f$, rather than + * @f$ (N \times T \times ...) @f$ -- means that the @f$ N @f$ + * independent input streams must be "interleaved". + * + * -# @f$ (T \times N) @f$ + * the sequence continuation indicators @f$ \delta @f$. + * These inputs should be binary (0 or 1) indicators, where + * @f$ \delta_{t,n} = 0 @f$ means that timestep @f$ t @f$ of stream + * @f$ n @f$ is the beginning of a new sequence, and hence the previous + * hidden state @f$ h_{t-1} @f$ is multiplied by @f$ \delta_t = 0 @f$ + * and has no effect on the cell's output at timestep @f$ t @f$, and + * a value of @f$ \delta_{t,n} = 1 @f$ means that timestep @f$ t @f$ of + * stream @f$ n @f$ is a continuation from the previous timestep + * @f$ t-1 @f$, and the previous hidden state @f$ h_{t-1} @f$ affects the + * updated hidden state and output. + * + * -# @f$ (N \times ...) @f$ (optional) + * the static (non-time-varying) input @f$ x_{static} @f$. + * After the first axis, whose dimension must be the number of + * independent streams, its dimensions may be arbitrary. + * This is mathematically equivalent to using a time-varying input of + * @f$ x'_t = [x_t; x_{static}] @f$ -- i.e., tiling the static input + * across the @f$ T @f$ timesteps and concatenating with the time-varying + * input. Note that if this input is used, all timesteps in a single + * batch within a particular one of the @f$ N @f$ streams must share the + * same static input, even if the sequence continuation indicators + * suggest that difference sequences are ending and beginning within a + * single batch. This may require padding and/or truncation for uniform + * length. + * + * @param top output Blob vector (length 1) + * -# @f$ (T \times N \times D) @f$ + * the time-varying output @f$ y @f$, where @f$ D @f$ is + * recurrent_param.num_output(). + * Refer to documentation for particular RecurrentLayer implementations + * (such as RNNLayer and LSTMLayer) for the definition of @f$ y @f$. + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// @brief A Net to implement the Recurrent functionality. + shared_ptr > unrolled_net_; + + /// @brief The number of independent streams to process simultaneously. + int N_; + + /** + * @brief The number of timesteps in the layer's input, and the number of + * timesteps over which to backpropagate through time. + */ + int T_; + + /// @brief Whether the layer has a "static" input copied across all timesteps. + bool static_input_; + + /** + * @brief The last layer to run in the network. 
(Any later layers are losses + * added to force the recurrent net to do backprop.) + */ + int last_layer_index_; + + /** + * @brief Whether the layer's hidden state at the first and last timesteps + * are layer inputs and outputs, respectively. + */ + bool expose_hidden_; + + vector* > recur_input_blobs_; + vector* > recur_output_blobs_; + vector* > output_blobs_; + Blob* x_input_blob_; + Blob* x_static_input_blob_; + Blob* cont_input_blob_; +}; + +} // namespace caffe + +#endif // CAFFE_RECURRENT_LAYER_HPP_ diff --git a/src/caffe/layers/recurrent_layer.cpp b/src/caffe/layers/recurrent_layer.cpp new file mode 100644 index 00000000000..e0c82773392 --- /dev/null +++ b/src/caffe/layers/recurrent_layer.cpp @@ -0,0 +1,295 @@ +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/recurrent_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void RecurrentLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + CHECK_GE(bottom[0]->num_axes(), 2) + << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)"; + T_ = bottom[0]->shape(0); + N_ = bottom[0]->shape(1); + LOG(INFO) << "Initializing recurrent layer: assuming input batch contains " + << T_ << " timesteps of " << N_ << " independent streams."; + + CHECK_EQ(bottom[1]->num_axes(), 2) + << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)"; + CHECK_EQ(T_, bottom[1]->shape(0)); + CHECK_EQ(N_, bottom[1]->shape(1)); + + // If expose_hidden is set, we take as input and produce as output + // the hidden state blobs at the first and last timesteps. + expose_hidden_ = this->layer_param_.recurrent_param().expose_hidden(); + + // Get (recurrent) input/output names. + vector output_names; + OutputBlobNames(&output_names); + vector recur_input_names; + RecurrentInputBlobNames(&recur_input_names); + vector recur_output_names; + RecurrentOutputBlobNames(&recur_output_names); + const int num_recur_blobs = recur_input_names.size(); + CHECK_EQ(num_recur_blobs, recur_output_names.size()); + + // If provided, bottom[2] is a static input to the recurrent net. + const int num_hidden_exposed = expose_hidden_ * num_recur_blobs; + static_input_ = (bottom.size() > 2 + num_hidden_exposed); + if (static_input_) { + CHECK_GE(bottom[2]->num_axes(), 1); + CHECK_EQ(N_, bottom[2]->shape(0)); + } + + // Create a NetParameter; setup the inputs that aren't unique to particular + // recurrent architectures. + NetParameter net_param; + + LayerParameter* input_layer_param = net_param.add_layer(); + input_layer_param->set_type("Input"); + InputParameter* input_param = input_layer_param->mutable_input_param(); + input_layer_param->add_top("x"); + BlobShape input_shape; + for (int i = 0; i < bottom[0]->num_axes(); ++i) { + input_shape.add_dim(bottom[0]->shape(i)); + } + input_param->add_shape()->CopyFrom(input_shape); + + input_shape.Clear(); + for (int i = 0; i < bottom[1]->num_axes(); ++i) { + input_shape.add_dim(bottom[1]->shape(i)); + } + input_layer_param->add_top("cont"); + input_param->add_shape()->CopyFrom(input_shape); + + if (static_input_) { + input_shape.Clear(); + for (int i = 0; i < bottom[2]->num_axes(); ++i) { + input_shape.add_dim(bottom[2]->shape(i)); + } + input_layer_param->add_top("x_static"); + input_param->add_shape()->CopyFrom(input_shape); + } + + // Call the child's FillUnrolledNet implementation to specify the unrolled + // recurrent architecture. 
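+  // (For example, RNNLayer's implementation below applies a shared W_xh
+  // InnerProduct to all timesteps at once and slices the result, then adds,
+  // per timestep, a Scale layer gating h_{t-1} by cont_t, InnerProduct
+  // layers sharing W_hh and W_ho, an Eltwise sum, and TanH nonlinearities;
+  // LSTMLayer adds an LSTMUnit layer per timestep instead.)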
+ this->FillUnrolledNet(&net_param); + + // Prepend this layer's name to the names of each layer in the unrolled net. + const string& layer_name = this->layer_param_.name(); + if (layer_name.size()) { + for (int i = 0; i < net_param.layer_size(); ++i) { + LayerParameter* layer = net_param.mutable_layer(i); + layer->set_name(layer_name + "_" + layer->name()); + } + } + + // Add "pseudo-losses" to all outputs to force backpropagation. + // (Setting force_backward is too aggressive as we may not need to backprop to + // all inputs, e.g., the sequence continuation indicators.) + vector pseudo_losses(output_names.size()); + for (int i = 0; i < output_names.size(); ++i) { + LayerParameter* layer = net_param.add_layer(); + pseudo_losses[i] = output_names[i] + "_pseudoloss"; + layer->set_name(pseudo_losses[i]); + layer->set_type("Reduction"); + layer->add_bottom(output_names[i]); + layer->add_top(pseudo_losses[i]); + layer->add_loss_weight(1); + } + + // Create the unrolled net. + unrolled_net_.reset(new Net(net_param)); + unrolled_net_->set_debug_info( + this->layer_param_.recurrent_param().debug_info()); + + // Setup pointers to the inputs. + x_input_blob_ = CHECK_NOTNULL(unrolled_net_->blob_by_name("x").get()); + cont_input_blob_ = CHECK_NOTNULL(unrolled_net_->blob_by_name("cont").get()); + if (static_input_) { + x_static_input_blob_ = + CHECK_NOTNULL(unrolled_net_->blob_by_name("x_static").get()); + } + + // Setup pointers to paired recurrent inputs/outputs. + recur_input_blobs_.resize(num_recur_blobs); + recur_output_blobs_.resize(num_recur_blobs); + for (int i = 0; i < recur_input_names.size(); ++i) { + recur_input_blobs_[i] = + CHECK_NOTNULL(unrolled_net_->blob_by_name(recur_input_names[i]).get()); + recur_output_blobs_[i] = + CHECK_NOTNULL(unrolled_net_->blob_by_name(recur_output_names[i]).get()); + } + + // Setup pointers to outputs. + CHECK_EQ(top.size() - num_hidden_exposed, output_names.size()) + << "OutputBlobNames must provide an output blob name for each top."; + output_blobs_.resize(output_names.size()); + for (int i = 0; i < output_names.size(); ++i) { + output_blobs_[i] = + CHECK_NOTNULL(unrolled_net_->blob_by_name(output_names[i]).get()); + } + + // We should have 2 inputs (x and cont), plus a number of recurrent inputs, + // plus maybe a static input. + CHECK_EQ(2 + num_recur_blobs + static_input_, + unrolled_net_->input_blobs().size()); + + // This layer's parameters are any parameters in the layers of the unrolled + // net. We only want one copy of each parameter, so check that the parameter + // is "owned" by the layer, rather than shared with another. + this->blobs_.clear(); + for (int i = 0; i < unrolled_net_->params().size(); ++i) { + if (unrolled_net_->param_owners()[i] == -1) { + LOG(INFO) << "Adding parameter " << i << ": " + << unrolled_net_->param_display_names()[i]; + this->blobs_.push_back(unrolled_net_->params()[i]); + } + } + // Check that param_propagate_down is set for all of the parameters in the + // unrolled net; set param_propagate_down to true in this layer. + for (int i = 0; i < unrolled_net_->layers().size(); ++i) { + for (int j = 0; j < unrolled_net_->layers()[i]->blobs().size(); ++j) { + CHECK(unrolled_net_->layers()[i]->param_propagate_down(j)) + << "param_propagate_down not set for layer " << i << ", param " << j; + } + } + this->param_propagate_down_.clear(); + this->param_propagate_down_.resize(this->blobs_.size(), true); + + // Set the diffs of recurrent outputs to 0 -- we can't backpropagate across + // batches. 
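+  // (Forward copies only the data of these blobs into the next batch's
+  // recurrent inputs, so no gradient flows across batches; the zero diffs
+  // also keep any unconsumed recurrent outputs, e.g. the LSTM's "c_T", from
+  // injecting uninitialized gradients during backward.)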
+ for (int i = 0; i < recur_output_blobs_.size(); ++i) { + caffe_set(recur_output_blobs_[i]->count(), Dtype(0), + recur_output_blobs_[i]->mutable_cpu_diff()); + } + + // Check that the last output_names.size() layers are the pseudo-losses; + // set last_layer_index so that we don't actually run these layers. + const vector& layer_names = unrolled_net_->layer_names(); + last_layer_index_ = layer_names.size() - 1 - pseudo_losses.size(); + for (int i = last_layer_index_ + 1, j = 0; i < layer_names.size(); ++i, ++j) { + CHECK_EQ(layer_names[i], pseudo_losses[j]); + } +} + +template +void RecurrentLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + CHECK_GE(bottom[0]->num_axes(), 2) + << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)"; + CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed"; + N_ = bottom[0]->shape(1); + CHECK_EQ(bottom[1]->num_axes(), 2) + << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)"; + CHECK_EQ(T_, bottom[1]->shape(0)); + CHECK_EQ(N_, bottom[1]->shape(1)); + x_input_blob_->ReshapeLike(*bottom[0]); + vector cont_shape = bottom[1]->shape(); + cont_input_blob_->Reshape(cont_shape); + if (static_input_) { + x_static_input_blob_->ReshapeLike(*bottom[2]); + } + vector recur_input_shapes; + RecurrentInputShapes(&recur_input_shapes); + CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size()); + for (int i = 0; i < recur_input_shapes.size(); ++i) { + recur_input_blobs_[i]->Reshape(recur_input_shapes[i]); + } + unrolled_net_->Reshape(); + x_input_blob_->ShareData(*bottom[0]); + x_input_blob_->ShareDiff(*bottom[0]); + cont_input_blob_->ShareData(*bottom[1]); + if (static_input_) { + x_static_input_blob_->ShareData(*bottom[2]); + x_static_input_blob_->ShareDiff(*bottom[2]); + } + if (expose_hidden_) { + const int bottom_offset = 2 + static_input_; + for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) { + CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape()) + << "bottom[" << i << "] shape must match hidden state input shape: " + << recur_input_blobs_[j]->shape_string(); + recur_input_blobs_[j]->ShareData(*bottom[i]); + } + } + for (int i = 0; i < output_blobs_.size(); ++i) { + top[i]->ReshapeLike(*output_blobs_[i]); + top[i]->ShareData(*output_blobs_[i]); + top[i]->ShareDiff(*output_blobs_[i]); + } + if (expose_hidden_) { + const int top_offset = output_blobs_.size(); + for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) { + top[i]->ReshapeLike(*recur_output_blobs_[j]); + } + } +} + +template +void RecurrentLayer::Reset() { + // "Reset" the hidden state of the net by zeroing out all recurrent outputs. + for (int i = 0; i < recur_output_blobs_.size(); ++i) { + caffe_set(recur_output_blobs_[i]->count(), Dtype(0), + recur_output_blobs_[i]->mutable_cpu_data()); + } +} + +template +void RecurrentLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + // Hacky fix for test time: reshare all the internal shared blobs, which may + // currently point to a stale owner blob that was dropped when Solver::Test + // called test_net->ShareTrainedLayersWith(net_.get()). + // TODO: somehow make this work non-hackily. 
+ if (this->phase_ == TEST) { + unrolled_net_->ShareWeights(); + } + + DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size()); + if (!expose_hidden_) { + for (int i = 0; i < recur_input_blobs_.size(); ++i) { + const int count = recur_input_blobs_[i]->count(); + DCHECK_EQ(count, recur_output_blobs_[i]->count()); + const Dtype* timestep_T_data = recur_output_blobs_[i]->cpu_data(); + Dtype* timestep_0_data = recur_input_blobs_[i]->mutable_cpu_data(); + caffe_copy(count, timestep_T_data, timestep_0_data); + } + } + + unrolled_net_->ForwardTo(last_layer_index_); + + if (expose_hidden_) { + const int top_offset = output_blobs_.size(); + for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) { + top[i]->ShareData(*recur_output_blobs_[j]); + } + } +} + +template +void RecurrentLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + CHECK(!propagate_down[1]) << "Cannot backpropagate to sequence indicators."; + + // TODO: skip backpropagation to inputs and parameters inside the unrolled + // net according to propagate_down[0] and propagate_down[2]. For now just + // backprop to inputs and parameters unconditionally, as either the inputs or + // the parameters do need backward (or Net would have set + // layer_needs_backward_[i] == false for this layer). + unrolled_net_->BackwardFrom(last_layer_index_); +} + +#ifdef CPU_ONLY +STUB_GPU_FORWARD(RecurrentLayer, Forward); +#endif + +INSTANTIATE_CLASS(RecurrentLayer); + +} // namespace caffe diff --git a/src/caffe/layers/recurrent_layer.cu b/src/caffe/layers/recurrent_layer.cu new file mode 100644 index 00000000000..4dd2b0e2165 --- /dev/null +++ b/src/caffe/layers/recurrent_layer.cu @@ -0,0 +1,44 @@ +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/recurrent_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void RecurrentLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + // Hacky fix for test time... reshare all the shared blobs. + // TODO: somehow make this work non-hackily. + if (this->phase_ == TEST) { + unrolled_net_->ShareWeights(); + } + + DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size()); + if (!expose_hidden_) { + for (int i = 0; i < recur_input_blobs_.size(); ++i) { + const int count = recur_input_blobs_[i]->count(); + DCHECK_EQ(count, recur_output_blobs_[i]->count()); + const Dtype* timestep_T_data = recur_output_blobs_[i]->gpu_data(); + Dtype* timestep_0_data = recur_input_blobs_[i]->mutable_gpu_data(); + caffe_copy(count, timestep_T_data, timestep_0_data); + } + } + + unrolled_net_->ForwardTo(last_layer_index_); + + if (expose_hidden_) { + const int top_offset = output_blobs_.size(); + for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) { + top[i]->ShareData(*recur_output_blobs_[j]); + } + } +} + +INSTANTIATE_LAYER_GPU_FORWARD(RecurrentLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 15810718631..1556781cbc2 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -306,7 +306,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. 
// -// LayerParameter next available layer-specific ID: 146 (last added: parameter_param) +// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -390,6 +390,7 @@ message LayerParameter { optional PowerParameter power_param = 122; optional PReLUParameter prelu_param = 131; optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; optional ReductionParameter reduction_param = 136; optional ReLUParameter relu_param = 123; optional ReshapeParameter reshape_param = 133; @@ -928,6 +929,25 @@ message PythonParameter { optional bool share_in_parallel = 4 [default = false]; } +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + // Message that stores parameters used by ReductionLayer message ReductionParameter { enum ReductionOp { From cf5f369574dd51045c1c92625c0fe6694a031f2a Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sun, 15 Feb 2015 14:56:50 -0800 Subject: [PATCH 2/3] Add RNNLayer, with tests --- include/caffe/layers/rnn_layer.hpp | 47 ++++++ src/caffe/layers/rnn_layer.cpp | 236 +++++++++++++++++++++++++++++ src/caffe/test/test_rnn_layer.cpp | 217 ++++++++++++++++++++++++++ 3 files changed, 500 insertions(+) create mode 100644 include/caffe/layers/rnn_layer.hpp create mode 100644 src/caffe/layers/rnn_layer.cpp create mode 100644 src/caffe/test/test_rnn_layer.cpp diff --git a/include/caffe/layers/rnn_layer.hpp b/include/caffe/layers/rnn_layer.hpp new file mode 100644 index 00000000000..6dce238ae17 --- /dev/null +++ b/include/caffe/layers/rnn_layer.hpp @@ -0,0 +1,47 @@ +#ifndef CAFFE_RNN_LAYER_HPP_ +#define CAFFE_RNN_LAYER_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/recurrent_layer.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template class RecurrentLayer; + +/** + * @brief Processes time-varying inputs using a simple recurrent neural network + * (RNN). Implemented as a network unrolling the RNN computation in time. + * + * Given time-varying inputs @f$ x_t @f$, computes hidden state @f$ + * h_t := \tanh[ W_{hh} h_{t_1} + W_{xh} x_t + b_h ] + * @f$, and outputs @f$ + * o_t := \tanh[ W_{ho} h_t + b_o ] + * @f$. 
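+ *
+ * The weight matrices @f$ W_{hh}, W_{xh}, W_{ho} @f$ and biases
+ * @f$ b_h, b_o @f$ are learned parameters shared across all timesteps of
+ * the unrolled net.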
+ */ +template +class RNNLayer : public RecurrentLayer { + public: + explicit RNNLayer(const LayerParameter& param) + : RecurrentLayer(param) {} + + virtual inline const char* type() const { return "RNN"; } + + protected: + virtual void FillUnrolledNet(NetParameter* net_param) const; + virtual void RecurrentInputBlobNames(vector* names) const; + virtual void RecurrentOutputBlobNames(vector* names) const; + virtual void RecurrentInputShapes(vector* shapes) const; + virtual void OutputBlobNames(vector* names) const; +}; + +} // namespace caffe + +#endif // CAFFE_RNN_LAYER_HPP_ diff --git a/src/caffe/layers/rnn_layer.cpp b/src/caffe/layers/rnn_layer.cpp new file mode 100644 index 00000000000..f62ae8c77de --- /dev/null +++ b/src/caffe/layers/rnn_layer.cpp @@ -0,0 +1,236 @@ +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/rnn_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void RNNLayer::RecurrentInputBlobNames(vector* names) const { + names->resize(1); + (*names)[0] = "h_0"; +} + +template +void RNNLayer::RecurrentOutputBlobNames(vector* names) const { + names->resize(1); + (*names)[0] = "h_" + format_int(this->T_); +} + +template +void RNNLayer::RecurrentInputShapes(vector* shapes) const { + const int num_output = this->layer_param_.recurrent_param().num_output(); + shapes->resize(1); + (*shapes)[0].Clear(); + (*shapes)[0].add_dim(1); // a single timestep + (*shapes)[0].add_dim(this->N_); + (*shapes)[0].add_dim(num_output); +} + +template +void RNNLayer::OutputBlobNames(vector* names) const { + names->resize(1); + (*names)[0] = "o"; +} + +template +void RNNLayer::FillUnrolledNet(NetParameter* net_param) const { + const int num_output = this->layer_param_.recurrent_param().num_output(); + CHECK_GT(num_output, 0) << "num_output must be positive"; + const FillerParameter& weight_filler = + this->layer_param_.recurrent_param().weight_filler(); + const FillerParameter& bias_filler = + this->layer_param_.recurrent_param().bias_filler(); + + // Add generic LayerParameter's (without bottoms/tops) of layer types we'll + // use to save redundant code. 
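+  // (hidden_param is the bias-free InnerProduct template reused for the
+  // W_hh and static-input transforms; biased_hidden_param adds bias_filler
+  // for the W_xh + b_h and W_ho + b_o transforms.)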
+ LayerParameter hidden_param; + hidden_param.set_type("InnerProduct"); + hidden_param.mutable_inner_product_param()->set_num_output(num_output); + hidden_param.mutable_inner_product_param()->set_bias_term(false); + hidden_param.mutable_inner_product_param()->set_axis(2); + hidden_param.mutable_inner_product_param()-> + mutable_weight_filler()->CopyFrom(weight_filler); + + LayerParameter biased_hidden_param(hidden_param); + biased_hidden_param.mutable_inner_product_param()->set_bias_term(true); + biased_hidden_param.mutable_inner_product_param()-> + mutable_bias_filler()->CopyFrom(bias_filler); + + LayerParameter sum_param; + sum_param.set_type("Eltwise"); + sum_param.mutable_eltwise_param()->set_operation( + EltwiseParameter_EltwiseOp_SUM); + + LayerParameter tanh_param; + tanh_param.set_type("TanH"); + + LayerParameter scale_param; + scale_param.set_type("Scale"); + scale_param.mutable_scale_param()->set_axis(0); + + LayerParameter slice_param; + slice_param.set_type("Slice"); + slice_param.mutable_slice_param()->set_axis(0); + + vector input_shapes; + RecurrentInputShapes(&input_shapes); + CHECK_EQ(1, input_shapes.size()); + + LayerParameter* input_layer_param = net_param->add_layer(); + input_layer_param->set_type("Input"); + InputParameter* input_param = input_layer_param->mutable_input_param(); + input_layer_param->add_top("h_0"); + input_param->add_shape()->CopyFrom(input_shapes[0]); + + LayerParameter* cont_slice_param = net_param->add_layer(); + cont_slice_param->CopyFrom(slice_param); + cont_slice_param->set_name("cont_slice"); + cont_slice_param->add_bottom("cont"); + cont_slice_param->mutable_slice_param()->set_axis(0); + + // Add layer to transform all timesteps of x to the hidden state dimension. + // W_xh_x = W_xh * x + b_h + { + LayerParameter* x_transform_param = net_param->add_layer(); + x_transform_param->CopyFrom(biased_hidden_param); + x_transform_param->set_name("x_transform"); + x_transform_param->add_param()->set_name("W_xh"); + x_transform_param->add_param()->set_name("b_h"); + x_transform_param->add_bottom("x"); + x_transform_param->add_top("W_xh_x"); + x_transform_param->add_propagate_down(true); + } + + if (this->static_input_) { + // Add layer to transform x_static to the hidden state dimension. + // W_xh_x_static = W_xh_static * x_static + LayerParameter* x_static_transform_param = net_param->add_layer(); + x_static_transform_param->CopyFrom(hidden_param); + x_static_transform_param->mutable_inner_product_param()->set_axis(1); + x_static_transform_param->set_name("W_xh_x_static"); + x_static_transform_param->add_param()->set_name("W_xh_static"); + x_static_transform_param->add_bottom("x_static"); + x_static_transform_param->add_top("W_xh_x_static_preshape"); + x_static_transform_param->add_propagate_down(true); + + LayerParameter* reshape_param = net_param->add_layer(); + reshape_param->set_type("Reshape"); + BlobShape* new_shape = + reshape_param->mutable_reshape_param()->mutable_shape(); + new_shape->add_dim(1); // One timestep. + // Should infer this->N as the dimension so we can reshape on batch size. 
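+    // (A dimension of -1 tells the Reshape layer to infer that axis from
+    // the blob's total count, so the batch size N_ need not be hard-coded.)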
+ new_shape->add_dim(-1); + new_shape->add_dim( + x_static_transform_param->inner_product_param().num_output()); + reshape_param->set_name("W_xh_x_static_reshape"); + reshape_param->add_bottom("W_xh_x_static_preshape"); + reshape_param->add_top("W_xh_x_static"); + } + + LayerParameter* x_slice_param = net_param->add_layer(); + x_slice_param->CopyFrom(slice_param); + x_slice_param->set_name("W_xh_x_slice"); + x_slice_param->add_bottom("W_xh_x"); + + LayerParameter output_concat_layer; + output_concat_layer.set_name("o_concat"); + output_concat_layer.set_type("Concat"); + output_concat_layer.add_top("o"); + output_concat_layer.mutable_concat_param()->set_axis(0); + + for (int t = 1; t <= this->T_; ++t) { + string tm1s = format_int(t - 1); + string ts = format_int(t); + + cont_slice_param->add_top("cont_" + ts); + x_slice_param->add_top("W_xh_x_" + ts); + + // Add layer to flush the hidden state when beginning a new sequence, + // as indicated by cont_t. + // h_conted_{t-1} := cont_t * h_{t-1} + // + // Normally, cont_t is binary (i.e., 0 or 1), so: + // h_conted_{t-1} := h_{t-1} if cont_t == 1 + // 0 otherwise + { + LayerParameter* cont_h_param = net_param->add_layer(); + cont_h_param->CopyFrom(scale_param); + cont_h_param->set_name("h_conted_" + tm1s); + cont_h_param->add_bottom("h_" + tm1s); + cont_h_param->add_bottom("cont_" + ts); + cont_h_param->add_top("h_conted_" + tm1s); + } + + // Add layer to compute + // W_hh_h_{t-1} := W_hh * h_conted_{t-1} + { + LayerParameter* w_param = net_param->add_layer(); + w_param->CopyFrom(hidden_param); + w_param->set_name("W_hh_h_" + tm1s); + w_param->add_param()->set_name("W_hh"); + w_param->add_bottom("h_conted_" + tm1s); + w_param->add_top("W_hh_h_" + tm1s); + w_param->mutable_inner_product_param()->set_axis(2); + } + + // Add layers to compute + // h_t := \tanh( W_hh * h_conted_{t-1} + W_xh * x_t + b_h ) + // = \tanh( W_hh_h_{t-1} + W_xh_t ) + { + LayerParameter* h_input_sum_param = net_param->add_layer(); + h_input_sum_param->CopyFrom(sum_param); + h_input_sum_param->set_name("h_input_sum_" + ts); + h_input_sum_param->add_bottom("W_hh_h_" + tm1s); + h_input_sum_param->add_bottom("W_xh_x_" + ts); + if (this->static_input_) { + h_input_sum_param->add_bottom("W_xh_x_static"); + } + h_input_sum_param->add_top("h_neuron_input_" + ts); + } + { + LayerParameter* h_neuron_param = net_param->add_layer(); + h_neuron_param->CopyFrom(tanh_param); + h_neuron_param->set_name("h_neuron_" + ts); + h_neuron_param->add_bottom("h_neuron_input_" + ts); + h_neuron_param->add_top("h_" + ts); + } + + // Add layer to compute + // W_ho_h_t := W_ho * h_t + b_o + { + LayerParameter* w_param = net_param->add_layer(); + w_param->CopyFrom(biased_hidden_param); + w_param->set_name("W_ho_h_" + ts); + w_param->add_param()->set_name("W_ho"); + w_param->add_param()->set_name("b_o"); + w_param->add_bottom("h_" + ts); + w_param->add_top("W_ho_h_" + ts); + w_param->mutable_inner_product_param()->set_axis(2); + } + + // Add layers to compute + // o_t := \tanh( W_ho h_t + b_o) + // = \tanh( W_ho_h_t ) + { + LayerParameter* o_neuron_param = net_param->add_layer(); + o_neuron_param->CopyFrom(tanh_param); + o_neuron_param->set_name("o_neuron_" + ts); + o_neuron_param->add_bottom("W_ho_h_" + ts); + o_neuron_param->add_top("o_" + ts); + } + output_concat_layer.add_bottom("o_" + ts); + } // for (int t = 1; t <= this->T_; ++t) + + net_param->add_layer()->CopyFrom(output_concat_layer); +} + +INSTANTIATE_CLASS(RNNLayer); +REGISTER_LAYER_CLASS(RNN); + +} // namespace caffe diff --git 
a/src/caffe/test/test_rnn_layer.cpp b/src/caffe/test/test_rnn_layer.cpp new file mode 100644 index 00000000000..dd8952d62d6 --- /dev/null +++ b/src/caffe/test/test_rnn_layer.cpp @@ -0,0 +1,217 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/rnn_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class RNNLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + RNNLayerTest() : num_output_(7) { + blob_bottom_vec_.push_back(&blob_bottom_); + blob_bottom_vec_.push_back(&blob_bottom_cont_); + blob_top_vec_.push_back(&blob_top_); + + ReshapeBlobs(1, 3); + + layer_param_.mutable_recurrent_param()->set_num_output(num_output_); + FillerParameter* weight_filler = + layer_param_.mutable_recurrent_param()->mutable_weight_filler(); + weight_filler->set_type("gaussian"); + weight_filler->set_std(0.2); + FillerParameter* bias_filler = + layer_param_.mutable_recurrent_param()->mutable_bias_filler(); + bias_filler->set_type("gaussian"); + bias_filler->set_std(0.1); + + layer_param_.set_phase(TEST); + } + + void ReshapeBlobs(int num_timesteps, int num_instances) { + blob_bottom_.Reshape(num_timesteps, num_instances, 3, 2); + blob_bottom_static_.Reshape(num_instances, 2, 3, 4); + vector shape(2); + shape[0] = num_timesteps; + shape[1] = num_instances; + blob_bottom_cont_.Reshape(shape); + + FillerParameter filler_param; + filler_param.set_min(-1); + filler_param.set_max(1); + UniformFiller filler(filler_param); + filler.Fill(&blob_bottom_); + } + + int num_output_; + LayerParameter layer_param_; + Blob blob_bottom_; + Blob blob_bottom_cont_; + Blob blob_bottom_static_; + Blob blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(RNNLayerTest, TestDtypesAndDevices); + +TYPED_TEST(RNNLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + RNNLayer layer(this->layer_param_); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + vector expected_top_shape = this->blob_bottom_.shape(); + expected_top_shape.resize(3); + expected_top_shape[2] = this->num_output_; + EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape); +} + +TYPED_TEST(RNNLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + const int kNumTimesteps = 3; + const int num = this->blob_bottom_.shape(1); + this->ReshapeBlobs(kNumTimesteps, num); + + // Fill the cont blob with <0, 1, 1, ..., 1>, + // indicating a sequence that begins at the first timestep + // then continues for the rest of the sequence. + for (int t = 0; t < kNumTimesteps; ++t) { + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[t * num + n] = t > 0; + } + } + + // Process the full sequence in a single batch. + FillerParameter filler_param; + filler_param.set_mean(0); + filler_param.set_std(1); + GaussianFiller sequence_filler(filler_param); + sequence_filler.Fill(&this->blob_bottom_); + shared_ptr > layer(new RNNLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + LOG(INFO) << "Calling forward for full sequence RNN"; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Copy the inputs and outputs to reuse/check them later. 
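+  // (These copies are compared below against single-timestep runs that feed
+  // the same data one step at a time.)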
+ Blob bottom_copy(this->blob_bottom_.shape()); + bottom_copy.CopyFrom(this->blob_bottom_); + Blob top_copy(this->blob_top_.shape()); + top_copy.CopyFrom(this->blob_top_); + + // Process the batch one timestep at a time; + // check that we get the same result. + this->ReshapeBlobs(1, num); + layer.reset(new RNNLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const int bottom_count = this->blob_bottom_.count(); + const int top_count = this->blob_top_.count(); + const Dtype kEpsilon = 1e-5; + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[n] = t > 0; + } + LOG(INFO) << "Calling forward for RNN timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + ASSERT_LT(t * top_count + i, top_copy.count()); + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } + } + + // Process the batch one timestep at a time with all cont blobs set to 0. + // Check that we get a different result, except in the first timestep. + Caffe::set_random_seed(1701); + layer.reset(new RNNLayer(this->layer_param_)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[n] = 0; + } + LOG(INFO) << "Calling forward for RNN timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + if (t == 0) { + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } else { + EXPECT_NE(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i]) + << "t = " << t << "; i = " << i; + } + } + } +} + +TYPED_TEST(RNNLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + RNNLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(RNNLayerTest, TestGradientNonZeroCont) { + typedef typename TypeParam::Dtype Dtype; + RNNLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(RNNLayerTest, TestGradientNonZeroContBufferSize2) { + typedef typename TypeParam::Dtype Dtype; + this->ReshapeBlobs(2, 2); + // fill the values + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(&this->blob_bottom_); + RNNLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(RNNLayerTest, TestGradientNonZeroContBufferSize2WithStaticInput) { + typedef typename TypeParam::Dtype Dtype; + this->ReshapeBlobs(2, 2); + FillerParameter 
filler_param; + UniformFiller filler(filler_param); + filler.Fill(&this->blob_bottom_); + filler.Fill(&this->blob_bottom_static_); + this->blob_bottom_vec_.push_back(&this->blob_bottom_static_); + RNNLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 2); +} + +} // namespace caffe From 51a68f0a0e9e376597d7cabae709ff969ad30c98 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Tue, 5 Apr 2016 09:56:04 -0700 Subject: [PATCH 3/3] Add LSTMLayer and LSTMUnitLayer, with tests --- include/caffe/layers/lstm_layer.hpp | 154 ++++++++++++++ src/caffe/layers/lstm_layer.cpp | 244 +++++++++++++++++++++++ src/caffe/layers/lstm_unit_layer.cpp | 131 ++++++++++++ src/caffe/layers/lstm_unit_layer.cu | 154 ++++++++++++++ src/caffe/test/test_lstm_layer.cpp | 288 +++++++++++++++++++++++++++ 5 files changed, 971 insertions(+) create mode 100644 include/caffe/layers/lstm_layer.hpp create mode 100644 src/caffe/layers/lstm_layer.cpp create mode 100644 src/caffe/layers/lstm_unit_layer.cpp create mode 100644 src/caffe/layers/lstm_unit_layer.cu create mode 100644 src/caffe/test/test_lstm_layer.cpp diff --git a/include/caffe/layers/lstm_layer.hpp b/include/caffe/layers/lstm_layer.hpp new file mode 100644 index 00000000000..a0e67c9d432 --- /dev/null +++ b/include/caffe/layers/lstm_layer.hpp @@ -0,0 +1,154 @@ +#ifndef CAFFE_LSTM_LAYER_HPP_ +#define CAFFE_LSTM_LAYER_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/recurrent_layer.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template class RecurrentLayer; + +/** + * @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM) + * [1] style recurrent neural network (RNN). Implemented by unrolling + * the LSTM computation through time. + * + * The specific architecture used in this implementation is as described in + * "Learning to Execute" [2], reproduced below: + * i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ] + * f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ] + * o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ] + * g_t := \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ] + * c_t := (f_t .* c_{t-1}) + (i_t .* g_t) + * h_t := o_t .* \tanh[c_t] + * In the implementation, the i, f, o, and g computations are performed as a + * single inner product. + * + * Notably, this implementation lacks the "diagonal" gates, as used in the + * LSTM architectures described by Alex Graves [3] and others. + * + * [1] Hochreiter, Sepp, and Schmidhuber, Jürgen. "Long short-term memory." + * Neural Computation 9, no. 8 (1997): 1735-1780. + * + * [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute." + * arXiv preprint arXiv:1410.4615 (2014). + * + * [3] Graves, Alex. "Generating sequences with recurrent neural networks." + * arXiv preprint arXiv:1308.0850 (2013). 
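+ *
+ * All gate parameters are learned and shared across timesteps; in the
+ * unrolled net the input-to-gate weights, gate biases, and hidden-to-gate
+ * weights appear as the packed parameter blobs "W_xc", "b_c", and "W_hc",
+ * respectively.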
+ */ +template +class LSTMLayer : public RecurrentLayer { + public: + explicit LSTMLayer(const LayerParameter& param) + : RecurrentLayer(param) {} + + virtual inline const char* type() const { return "LSTM"; } + + protected: + virtual void FillUnrolledNet(NetParameter* net_param) const; + virtual void RecurrentInputBlobNames(vector* names) const; + virtual void RecurrentOutputBlobNames(vector* names) const; + virtual void RecurrentInputShapes(vector* shapes) const; + virtual void OutputBlobNames(vector* names) const; +}; + +/** + * @brief A helper for LSTMLayer: computes a single timestep of the + * non-linearity of the LSTM, producing the updated cell and hidden + * states. + */ +template +class LSTMUnitLayer : public Layer { + public: + explicit LSTMUnitLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "LSTMUnit"; } + virtual inline int ExactNumBottomBlobs() const { return 3; } + virtual inline int ExactNumTopBlobs() const { return 2; } + + virtual inline bool AllowForceBackward(const int bottom_index) const { + // Can't propagate to sequence continuation indicators. + return bottom_index != 2; + } + + protected: + /** + * @param bottom input Blob vector (length 3) + * -# @f$ (1 \times N \times D) @f$ + * the previous timestep cell state @f$ c_{t-1} @f$ + * -# @f$ (1 \times N \times 4D) @f$ + * the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$ + * -# @f$ (1 \times N) @f$ + * the sequence continuation indicators @f$ \delta_t @f$ + * @param top output Blob vector (length 2) + * -# @f$ (1 \times N \times D) @f$ + * the updated cell state @f$ c_t @f$, computed as: + * i_t := \sigmoid[i_t'] + * f_t := \sigmoid[f_t'] + * o_t := \sigmoid[o_t'] + * g_t := \tanh[g_t'] + * c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t) + * -# @f$ (1 \times N \times D) @f$ + * the updated hidden state @f$ h_t @f$, computed as: + * h_t := o_t .* \tanh[c_t] + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the LSTMUnit inputs. + * + * @param top output Blob vector (length 2), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times N \times D) @f$: + * containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$ + * with respect to the updated cell state @f$ c_t @f$ + * -# @f$ (1 \times N \times D) @f$: + * containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$ + * with respect to the updated cell state @f$ h_t @f$ + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 3), into which the error gradients + * with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate + * inputs are computed. Computatation of the error gradients w.r.t. + * the sequence indicators is not implemented. + * -# @f$ (1 \times N \times D) @f$ + * the error gradient w.r.t. the previous timestep cell state + * @f$ c_{t-1} @f$ + * -# @f$ (1 \times N \times 4D) @f$ + * the error gradient w.r.t. the "gate inputs" + * @f$ [ + * \frac{\partial E}{\partial i_t} + * \frac{\partial E}{\partial f_t} + * \frac{\partial E}{\partial o_t} + * \frac{\partial E}{\partial g_t} + * ] @f$ + * -# @f$ (1 \times 1 \times N) @f$ + * the gradient w.r.t. the sequence continuation indicators + * @f$ \delta_t @f$ is currently not computed. 
+ */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// @brief The hidden and output dimension. + int hidden_dim_; + Blob X_acts_; +}; + +} // namespace caffe + +#endif // CAFFE_LSTM_LAYER_HPP_ diff --git a/src/caffe/layers/lstm_layer.cpp b/src/caffe/layers/lstm_layer.cpp new file mode 100644 index 00000000000..da48dba4c05 --- /dev/null +++ b/src/caffe/layers/lstm_layer.cpp @@ -0,0 +1,244 @@ +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/layers/lstm_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void LSTMLayer::RecurrentInputBlobNames(vector* names) const { + names->resize(2); + (*names)[0] = "h_0"; + (*names)[1] = "c_0"; +} + +template +void LSTMLayer::RecurrentOutputBlobNames(vector* names) const { + names->resize(2); + (*names)[0] = "h_" + format_int(this->T_); + (*names)[1] = "c_T"; +} + +template +void LSTMLayer::RecurrentInputShapes(vector* shapes) const { + const int num_output = this->layer_param_.recurrent_param().num_output(); + const int num_blobs = 2; + shapes->resize(num_blobs); + for (int i = 0; i < num_blobs; ++i) { + (*shapes)[i].Clear(); + (*shapes)[i].add_dim(1); // a single timestep + (*shapes)[i].add_dim(this->N_); + (*shapes)[i].add_dim(num_output); + } +} + +template +void LSTMLayer::OutputBlobNames(vector* names) const { + names->resize(1); + (*names)[0] = "h"; +} + +template +void LSTMLayer::FillUnrolledNet(NetParameter* net_param) const { + const int num_output = this->layer_param_.recurrent_param().num_output(); + CHECK_GT(num_output, 0) << "num_output must be positive"; + const FillerParameter& weight_filler = + this->layer_param_.recurrent_param().weight_filler(); + const FillerParameter& bias_filler = + this->layer_param_.recurrent_param().bias_filler(); + + // Add generic LayerParameter's (without bottoms/tops) of layer types we'll + // use to save redundant code. 
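+  // (num_output * 4 packs the i, f, o, and g gate pre-activations into one
+  // InnerProduct output, which LSTMUnitLayer later splits apart.)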
+ LayerParameter hidden_param; + hidden_param.set_type("InnerProduct"); + hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4); + hidden_param.mutable_inner_product_param()->set_bias_term(false); + hidden_param.mutable_inner_product_param()->set_axis(2); + hidden_param.mutable_inner_product_param()-> + mutable_weight_filler()->CopyFrom(weight_filler); + + LayerParameter biased_hidden_param(hidden_param); + biased_hidden_param.mutable_inner_product_param()->set_bias_term(true); + biased_hidden_param.mutable_inner_product_param()-> + mutable_bias_filler()->CopyFrom(bias_filler); + + LayerParameter sum_param; + sum_param.set_type("Eltwise"); + sum_param.mutable_eltwise_param()->set_operation( + EltwiseParameter_EltwiseOp_SUM); + + LayerParameter scale_param; + scale_param.set_type("Scale"); + scale_param.mutable_scale_param()->set_axis(0); + + LayerParameter slice_param; + slice_param.set_type("Slice"); + slice_param.mutable_slice_param()->set_axis(0); + + LayerParameter split_param; + split_param.set_type("Split"); + + vector input_shapes; + RecurrentInputShapes(&input_shapes); + CHECK_EQ(2, input_shapes.size()); + + LayerParameter* input_layer_param = net_param->add_layer(); + input_layer_param->set_type("Input"); + InputParameter* input_param = input_layer_param->mutable_input_param(); + + input_layer_param->add_top("c_0"); + input_param->add_shape()->CopyFrom(input_shapes[0]); + + input_layer_param->add_top("h_0"); + input_param->add_shape()->CopyFrom(input_shapes[1]); + + LayerParameter* cont_slice_param = net_param->add_layer(); + cont_slice_param->CopyFrom(slice_param); + cont_slice_param->set_name("cont_slice"); + cont_slice_param->add_bottom("cont"); + cont_slice_param->mutable_slice_param()->set_axis(0); + + // Add layer to transform all timesteps of x to the hidden state dimension. + // W_xc_x = W_xc * x + b_c + { + LayerParameter* x_transform_param = net_param->add_layer(); + x_transform_param->CopyFrom(biased_hidden_param); + x_transform_param->set_name("x_transform"); + x_transform_param->add_param()->set_name("W_xc"); + x_transform_param->add_param()->set_name("b_c"); + x_transform_param->add_bottom("x"); + x_transform_param->add_top("W_xc_x"); + x_transform_param->add_propagate_down(true); + } + + if (this->static_input_) { + // Add layer to transform x_static to the gate dimension. + // W_xc_x_static = W_xc_static * x_static + LayerParameter* x_static_transform_param = net_param->add_layer(); + x_static_transform_param->CopyFrom(hidden_param); + x_static_transform_param->mutable_inner_product_param()->set_axis(1); + x_static_transform_param->set_name("W_xc_x_static"); + x_static_transform_param->add_param()->set_name("W_xc_static"); + x_static_transform_param->add_bottom("x_static"); + x_static_transform_param->add_top("W_xc_x_static_preshape"); + x_static_transform_param->add_propagate_down(true); + + LayerParameter* reshape_param = net_param->add_layer(); + reshape_param->set_type("Reshape"); + BlobShape* new_shape = + reshape_param->mutable_reshape_param()->mutable_shape(); + new_shape->add_dim(1); // One timestep. + // Should infer this->N as the dimension so we can reshape on batch size. 
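+    // (As in RNNLayer above, -1 lets the Reshape layer infer this axis from
+    // the blob's count, tracking the batch size N_.)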
+ new_shape->add_dim(-1); + new_shape->add_dim( + x_static_transform_param->inner_product_param().num_output()); + reshape_param->set_name("W_xc_x_static_reshape"); + reshape_param->add_bottom("W_xc_x_static_preshape"); + reshape_param->add_top("W_xc_x_static"); + } + + LayerParameter* x_slice_param = net_param->add_layer(); + x_slice_param->CopyFrom(slice_param); + x_slice_param->add_bottom("W_xc_x"); + x_slice_param->set_name("W_xc_x_slice"); + + LayerParameter output_concat_layer; + output_concat_layer.set_name("h_concat"); + output_concat_layer.set_type("Concat"); + output_concat_layer.add_top("h"); + output_concat_layer.mutable_concat_param()->set_axis(0); + + for (int t = 1; t <= this->T_; ++t) { + string tm1s = format_int(t - 1); + string ts = format_int(t); + + cont_slice_param->add_top("cont_" + ts); + x_slice_param->add_top("W_xc_x_" + ts); + + // Add layers to flush the hidden state when beginning a new + // sequence, as indicated by cont_t. + // h_conted_{t-1} := cont_t * h_{t-1} + // + // Normally, cont_t is binary (i.e., 0 or 1), so: + // h_conted_{t-1} := h_{t-1} if cont_t == 1 + // 0 otherwise + { + LayerParameter* cont_h_param = net_param->add_layer(); + cont_h_param->CopyFrom(scale_param); + cont_h_param->set_name("h_conted_" + tm1s); + cont_h_param->add_bottom("h_" + tm1s); + cont_h_param->add_bottom("cont_" + ts); + cont_h_param->add_top("h_conted_" + tm1s); + } + + // Add layer to compute + // W_hc_h_{t-1} := W_hc * h_conted_{t-1} + { + LayerParameter* w_param = net_param->add_layer(); + w_param->CopyFrom(hidden_param); + w_param->set_name("transform_" + ts); + w_param->add_param()->set_name("W_hc"); + w_param->add_bottom("h_conted_" + tm1s); + w_param->add_top("W_hc_h_" + tm1s); + w_param->mutable_inner_product_param()->set_axis(2); + } + + // Add the outputs of the linear transformations to compute the gate input. + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + // = W_hc_h_{t-1} + W_xc_x_t + b_c + { + LayerParameter* input_sum_layer = net_param->add_layer(); + input_sum_layer->CopyFrom(sum_param); + input_sum_layer->set_name("gate_input_" + ts); + input_sum_layer->add_bottom("W_hc_h_" + tm1s); + input_sum_layer->add_bottom("W_xc_x_" + ts); + if (this->static_input_) { + input_sum_layer->add_bottom("W_xc_x_static"); + } + input_sum_layer->add_top("gate_input_" + ts); + } + + // Add LSTMUnit layer to compute the cell & hidden vectors c_t and h_t. 
+ // Inputs: c_{t-1}, gate_input_t = (i_t, f_t, o_t, g_t), cont_t + // Outputs: c_t, h_t + // [ i_t' ] + // [ f_t' ] := gate_input_t + // [ o_t' ] + // [ g_t' ] + // i_t := \sigmoid[i_t'] + // f_t := \sigmoid[f_t'] + // o_t := \sigmoid[o_t'] + // g_t := \tanh[g_t'] + // c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t) + // h_t := o_t .* \tanh[c_t] + { + LayerParameter* lstm_unit_param = net_param->add_layer(); + lstm_unit_param->set_type("LSTMUnit"); + lstm_unit_param->add_bottom("c_" + tm1s); + lstm_unit_param->add_bottom("gate_input_" + ts); + lstm_unit_param->add_bottom("cont_" + ts); + lstm_unit_param->add_top("c_" + ts); + lstm_unit_param->add_top("h_" + ts); + lstm_unit_param->set_name("unit_" + ts); + } + output_concat_layer.add_bottom("h_" + ts); + } // for (int t = 1; t <= this->T_; ++t) + + { + LayerParameter* c_T_copy_param = net_param->add_layer(); + c_T_copy_param->CopyFrom(split_param); + c_T_copy_param->add_bottom("c_" + format_int(this->T_)); + c_T_copy_param->add_top("c_T"); + } + net_param->add_layer()->CopyFrom(output_concat_layer); +} + +INSTANTIATE_CLASS(LSTMLayer); +REGISTER_LAYER_CLASS(LSTM); + +} // namespace caffe diff --git a/src/caffe/layers/lstm_unit_layer.cpp b/src/caffe/layers/lstm_unit_layer.cpp new file mode 100644 index 00000000000..277c031ad15 --- /dev/null +++ b/src/caffe/layers/lstm_unit_layer.cpp @@ -0,0 +1,131 @@ +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/layers/lstm_layer.hpp" + +namespace caffe { + +template +inline Dtype sigmoid(Dtype x) { + return 1. / (1. + exp(-x)); +} + +template +inline Dtype tanh(Dtype x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +void LSTMUnitLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + const int num_instances = bottom[0]->shape(1); + for (int i = 0; i < bottom.size(); ++i) { + if (i == 2) { + CHECK_EQ(2, bottom[i]->num_axes()); + } else { + CHECK_EQ(3, bottom[i]->num_axes()); + } + CHECK_EQ(1, bottom[i]->shape(0)); + CHECK_EQ(num_instances, bottom[i]->shape(1)); + } + hidden_dim_ = bottom[0]->shape(2); + CHECK_EQ(num_instances, bottom[1]->shape(1)); + CHECK_EQ(4 * hidden_dim_, bottom[1]->shape(2)); + top[0]->ReshapeLike(*bottom[0]); + top[1]->ReshapeLike(*bottom[0]); + X_acts_.ReshapeLike(*bottom[1]); +} + +template +void LSTMUnitLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const int num = bottom[0]->shape(1); + const int x_dim = hidden_dim_ * 4; + const Dtype* C_prev = bottom[0]->cpu_data(); + const Dtype* X = bottom[1]->cpu_data(); + const Dtype* cont = bottom[2]->cpu_data(); + Dtype* C = top[0]->mutable_cpu_data(); + Dtype* H = top[1]->mutable_cpu_data(); + for (int n = 0; n < num; ++n) { + for (int d = 0; d < hidden_dim_; ++d) { + const Dtype i = sigmoid(X[d]); + const Dtype f = (*cont == 0) ? 
0 : + (*cont * sigmoid(X[1 * hidden_dim_ + d])); + const Dtype o = sigmoid(X[2 * hidden_dim_ + d]); + const Dtype g = tanh(X[3 * hidden_dim_ + d]); + const Dtype c_prev = C_prev[d]; + const Dtype c = f * c_prev + i * g; + C[d] = c; + const Dtype tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += hidden_dim_; + X += x_dim; + C += hidden_dim_; + H += hidden_dim_; + ++cont; + } +} + +template +void LSTMUnitLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators."; + if (!propagate_down[0] && !propagate_down[1]) { return; } + + const int num = bottom[0]->shape(1); + const int x_dim = hidden_dim_ * 4; + const Dtype* C_prev = bottom[0]->cpu_data(); + const Dtype* X = bottom[1]->cpu_data(); + const Dtype* cont = bottom[2]->cpu_data(); + const Dtype* C = top[0]->cpu_data(); + const Dtype* H = top[1]->cpu_data(); + const Dtype* C_diff = top[0]->cpu_diff(); + const Dtype* H_diff = top[1]->cpu_diff(); + Dtype* C_prev_diff = bottom[0]->mutable_cpu_diff(); + Dtype* X_diff = bottom[1]->mutable_cpu_diff(); + for (int n = 0; n < num; ++n) { + for (int d = 0; d < hidden_dim_; ++d) { + const Dtype i = sigmoid(X[d]); + const Dtype f = (*cont == 0) ? 0 : + (*cont * sigmoid(X[1 * hidden_dim_ + d])); + const Dtype o = sigmoid(X[2 * hidden_dim_ + d]); + const Dtype g = tanh(X[3 * hidden_dim_ + d]); + const Dtype c_prev = C_prev[d]; + const Dtype c = C[d]; + const Dtype tanh_c = tanh(c); + Dtype* c_prev_diff = C_prev_diff + d; + Dtype* i_diff = X_diff + d; + Dtype* f_diff = X_diff + 1 * hidden_dim_ + d; + Dtype* o_diff = X_diff + 2 * hidden_dim_ + d; + Dtype* g_diff = X_diff + 3 * hidden_dim_ + d; + const Dtype c_term_diff = + C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += hidden_dim_; + X += x_dim; + C += hidden_dim_; + H += hidden_dim_; + C_diff += hidden_dim_; + H_diff += hidden_dim_; + X_diff += x_dim; + C_prev_diff += hidden_dim_; + ++cont; + } +} + +#ifdef CPU_ONLY +STUB_GPU(LSTMUnitLayer); +#endif + +INSTANTIATE_CLASS(LSTMUnitLayer); +REGISTER_LAYER_CLASS(LSTMUnit); + +} // namespace caffe diff --git a/src/caffe/layers/lstm_unit_layer.cu b/src/caffe/layers/lstm_unit_layer.cu new file mode 100644 index 00000000000..15bb451d9e0 --- /dev/null +++ b/src/caffe/layers/lstm_unit_layer.cu @@ -0,0 +1,154 @@ +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/layers/lstm_layer.hpp" + +namespace caffe { + +template +__device__ Dtype sigmoid(const Dtype x) { + return Dtype(1) / (Dtype(1) + exp(-x)); +} + +template +__device__ Dtype tanh(const Dtype x) { + return Dtype(2) * sigmoid(Dtype(2) * x) - Dtype(1); +} + +template +__global__ void LSTMActsForward(const int nthreads, const int dim, + const Dtype* X, Dtype* X_acts) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int x_dim = 4 * dim; + const int d = index % x_dim; + if (d < 3 * dim) { + X_acts[index] = sigmoid(X[index]); + } else { + X_acts[index] = tanh(X[index]); + } + } +} + +template +__global__ void LSTMUnitForward(const int nthreads, const int dim, + const Dtype* C_prev, const Dtype* X, const Dtype* cont, + Dtype* C, Dtype* H) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const Dtype* X_offset = X + 4 * dim * n; + const Dtype i = 
X_offset[d]; + const Dtype f = X_offset[1 * dim + d]; + const Dtype o = X_offset[2 * dim + d]; + const Dtype g = X_offset[3 * dim + d]; + const Dtype c_prev = C_prev[index]; + const Dtype c = cont[n] * f * c_prev + i * g; + C[index] = c; + const Dtype tanh_c = tanh(c); + H[index] = o * tanh_c; + } +} + +template +void LSTMUnitLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const int count = top[1]->count(); + const Dtype* C_prev = bottom[0]->gpu_data(); + const Dtype* X = bottom[1]->gpu_data(); + const Dtype* cont = bottom[2]->gpu_data(); + Dtype* X_acts = X_acts_.mutable_gpu_data(); + Dtype* C = top[0]->mutable_gpu_data(); + Dtype* H = top[1]->mutable_gpu_data(); + const int X_count = bottom[1]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + LSTMActsForward<<>>( + X_count, hidden_dim_, X, X_acts); + CUDA_POST_KERNEL_CHECK; + // NOLINT_NEXT_LINE(whitespace/operators) + LSTMUnitForward<<>>( + count, hidden_dim_, C_prev, X_acts, cont, C, H); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void LSTMUnitBackward(const int nthreads, const int dim, + const Dtype* C_prev, const Dtype* X, const Dtype* C, const Dtype* H, + const Dtype* cont, const Dtype* C_diff, const Dtype* H_diff, + Dtype* C_prev_diff, Dtype* X_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const Dtype* X_offset = X + 4 * dim * n; + const Dtype i = X_offset[d]; + const Dtype f = X_offset[1 * dim + d]; + const Dtype o = X_offset[2 * dim + d]; + const Dtype g = X_offset[3 * dim + d]; + const Dtype c_prev = C_prev[index]; + const Dtype c = C[index]; + const Dtype tanh_c = tanh(c); + Dtype* c_prev_diff = C_prev_diff + index; + Dtype* X_diff_offset = X_diff + 4 * dim * n; + Dtype* i_diff = X_diff_offset + d; + Dtype* f_diff = X_diff_offset + 1 * dim + d; + Dtype* o_diff = X_diff_offset + 2 * dim + d; + Dtype* g_diff = X_diff_offset + 3 * dim + d; + const Dtype c_term_diff = + C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); + const Dtype cont_n = cont[n]; + *c_prev_diff = cont_n * c_term_diff * f; + *i_diff = c_term_diff * g; + *f_diff = cont_n * c_term_diff * c_prev; + *o_diff = H_diff[index] * tanh_c; + *g_diff = c_term_diff * i; + } +} + +template +__global__ void LSTMActsBackward(const int nthreads, const int dim, + const Dtype* X_acts, const Dtype* X_acts_diff, Dtype* X_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + const int x_dim = 4 * dim; + const int d = index % x_dim; + const Dtype X_act = X_acts[index]; + if (d < 3 * dim) { + X_diff[index] = X_acts_diff[index] * X_act * (Dtype(1) - X_act); + } else { + X_diff[index] = X_acts_diff[index] * (Dtype(1) - X_act * X_act); + } + } +} + +template +void LSTMUnitLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators."; + if (!propagate_down[0] && !propagate_down[1]) { return; } + + const int count = top[1]->count(); + const Dtype* C_prev = bottom[0]->gpu_data(); + const Dtype* X_acts = X_acts_.gpu_data(); + const Dtype* cont = bottom[2]->gpu_data(); + const Dtype* C = top[0]->gpu_data(); + const Dtype* H = top[1]->gpu_data(); + const Dtype* C_diff = top[0]->gpu_diff(); + const Dtype* H_diff = top[1]->gpu_diff(); + Dtype* C_prev_diff = bottom[0]->mutable_gpu_diff(); + Dtype* X_acts_diff = X_acts_.mutable_gpu_diff(); + LSTMUnitBackward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>(count, hidden_dim_, + C_prev, X_acts, C, H, cont, C_diff, H_diff, C_prev_diff, 
X_acts_diff); + CUDA_POST_KERNEL_CHECK; + const int X_count = bottom[1]->count(); + Dtype* X_diff = bottom[1]->mutable_gpu_diff(); + LSTMActsBackward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + X_count, hidden_dim_, X_acts, X_acts_diff, X_diff); + CUDA_POST_KERNEL_CHECK; +} + +INSTANTIATE_LAYER_GPU_FUNCS(LSTMUnitLayer); + +} // namespace caffe diff --git a/src/caffe/test/test_lstm_layer.cpp b/src/caffe/test/test_lstm_layer.cpp new file mode 100644 index 00000000000..51905baafac --- /dev/null +++ b/src/caffe/test/test_lstm_layer.cpp @@ -0,0 +1,288 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/lstm_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class LSTMLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + LSTMLayerTest() : num_output_(7) { + blob_bottom_vec_.push_back(&blob_bottom_); + blob_bottom_vec_.push_back(&blob_bottom_cont_); + blob_top_vec_.push_back(&blob_top_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_c_prev_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_x_); + unit_blob_bottom_vec_.push_back(&unit_blob_bottom_cont_); + unit_blob_top_vec_.push_back(&unit_blob_top_c_); + unit_blob_top_vec_.push_back(&unit_blob_top_h_); + + ReshapeBlobs(1, 3); + + layer_param_.mutable_recurrent_param()->set_num_output(num_output_); + FillerParameter* weight_filler = + layer_param_.mutable_recurrent_param()->mutable_weight_filler(); + weight_filler->set_type("gaussian"); + weight_filler->set_std(0.2); + FillerParameter* bias_filler = + layer_param_.mutable_recurrent_param()->mutable_bias_filler(); + bias_filler->set_type("gaussian"); + bias_filler->set_std(0.1); + + layer_param_.set_phase(TEST); + } + + void ReshapeBlobs(int num_timesteps, int num_instances) { + blob_bottom_.Reshape(num_timesteps, num_instances, 3, 2); + blob_bottom_static_.Reshape(num_instances, 2, 3, 4); + vector shape(2); + shape[0] = num_timesteps; + shape[1] = num_instances; + blob_bottom_cont_.Reshape(shape); + shape.push_back(num_output_); + + shape[0] = 1; shape[1] = num_instances; shape[2] = 4 * num_output_; + unit_blob_bottom_x_.Reshape(shape); + shape[0] = 1; shape[1] = num_instances; shape[2] = num_output_; + unit_blob_bottom_c_prev_.Reshape(shape); + shape.resize(2); + shape[0] = 1; shape[1] = num_instances; + unit_blob_bottom_cont_.Reshape(shape); + + FillerParameter filler_param; + filler_param.set_min(-1); + filler_param.set_max(1); + UniformFiller filler(filler_param); + filler.Fill(&blob_bottom_); + filler.Fill(&unit_blob_bottom_c_prev_); + filler.Fill(&unit_blob_bottom_x_); + } + + int num_output_; + LayerParameter layer_param_; + Blob blob_bottom_; + Blob blob_bottom_cont_; + Blob blob_bottom_static_; + Blob blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + + Blob unit_blob_bottom_cont_; + Blob unit_blob_bottom_c_prev_; + Blob unit_blob_bottom_x_; + Blob unit_blob_top_c_; + Blob unit_blob_top_h_; + vector*> unit_blob_bottom_vec_; + vector*> unit_blob_top_vec_; +}; + +TYPED_TEST_CASE(LSTMLayerTest, TestDtypesAndDevices); + +TYPED_TEST(LSTMLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + vector expected_top_shape = this->blob_bottom_.shape(); + expected_top_shape.resize(3); + expected_top_shape[2] = 
this->num_output_; + EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape); +} + +TYPED_TEST(LSTMLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + const int kNumTimesteps = 3; + const int num = this->blob_bottom_.shape(1); + this->ReshapeBlobs(kNumTimesteps, num); + + // Fill the cont blob with <0, 1, 1, ..., 1>, + // indicating a sequence that begins at the first timestep + // then continues for the rest of the sequence. + for (int t = 0; t < kNumTimesteps; ++t) { + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[t * num + n] = t > 0; + } + } + + // Process the full sequence in a single batch. + FillerParameter filler_param; + filler_param.set_mean(0); + filler_param.set_std(1); + GaussianFiller sequence_filler(filler_param); + Caffe::set_random_seed(1); + sequence_filler.Fill(&this->blob_bottom_); + shared_ptr > layer(new LSTMLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + LOG(INFO) << "Calling forward for full sequence LSTM"; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Copy the inputs and outputs to reuse/check them later. + Blob bottom_copy(this->blob_bottom_.shape()); + bottom_copy.CopyFrom(this->blob_bottom_); + Blob top_copy(this->blob_top_.shape()); + top_copy.CopyFrom(this->blob_top_); + + // Process the batch one timestep at a time; + // check that we get the same result. + this->ReshapeBlobs(1, num); + layer.reset(new LSTMLayer(this->layer_param_)); + Caffe::set_random_seed(1701); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const int bottom_count = this->blob_bottom_.count(); + const int top_count = this->blob_top_.count(); + const Dtype kEpsilon = 1e-5; + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[n] = t > 0; + } + LOG(INFO) << "Calling forward for LSTM timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + ASSERT_LT(t * top_count + i, top_copy.count()); + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } + } + + // Process the batch one timestep at a time with all cont blobs set to 0. + // Check that we get a different result, except in the first timestep. 
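+  // (With all cont values at 0, the previous hidden and cell states are
+  // zeroed before every step, so each timestep is treated as the start of a
+  // new sequence and only the t == 0 outputs can match the full-sequence run.)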
+ Caffe::set_random_seed(1701); + layer.reset(new LSTMLayer(this->layer_param_)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + for (int t = 0; t < kNumTimesteps; ++t) { + caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count, + this->blob_bottom_.mutable_cpu_data()); + for (int n = 0; n < num; ++n) { + this->blob_bottom_cont_.mutable_cpu_data()[n] = 0; + } + LOG(INFO) << "Calling forward for LSTM timestep " << t; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < top_count; ++i) { + if (t == 0) { + EXPECT_NEAR(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i], kEpsilon) + << "t = " << t << "; i = " << i; + } else { + EXPECT_NE(this->blob_top_.cpu_data()[i], + top_copy.cpu_data()[t * top_count + i]) + << "t = " << t << "; i = " << i; + } + } + } +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitSetUp) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + layer.SetUp(this->unit_blob_bottom_vec_, this->unit_blob_top_vec_); + const int num_axes = this->unit_blob_bottom_c_prev_.num_axes(); + ASSERT_EQ(num_axes, this->unit_blob_top_c_.num_axes()); + ASSERT_EQ(num_axes, this->unit_blob_top_h_.num_axes()); + for (int i = 0; i < num_axes; ++i) { + EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i), + this->unit_blob_top_c_.shape(i)); + EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i), + this->unit_blob_top_h_.shape(i)); + } +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + Dtype* cont_data = this->blob_bottom_cont_.mutable_cpu_data(); + cont_data[0] = 0; + cont_data[1] = 0; + cont_data[2] = 0; + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 1); +} + +TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradientNonZeroCont) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LSTMUnitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + Dtype* cont_data = this->blob_bottom_cont_.mutable_cpu_data(); + cont_data[0] = 1; + cont_data[1] = 0; + cont_data[2] = 1; + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_, + this->unit_blob_top_vec_, 1); +} + +TYPED_TEST(LSTMLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(LSTMLayerTest, TestGradientNonZeroCont) { + typedef typename TypeParam::Dtype Dtype; + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(LSTMLayerTest, TestGradientNonZeroContBufferSize2) { + typedef typename TypeParam::Dtype Dtype; + this->ReshapeBlobs(2, 2); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(&this->blob_bottom_); + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i 
= 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(LSTMLayerTest, TestGradientNonZeroContBufferSize2WithStaticInput) { + typedef typename TypeParam::Dtype Dtype; + this->ReshapeBlobs(2, 2); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(&this->blob_bottom_); + filler.Fill(&this->blob_bottom_static_); + this->blob_bottom_vec_.push_back(&this->blob_bottom_static_); + LSTMLayer layer(this->layer_param_); + GradientChecker checker(1e-2, 1e-3); + for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) { + this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 2); +} + + +} // namespace caffe
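As a reading aid, the per-element math implemented by the LSTMUnitLayer Forward/Backward code above (the CPU and GPU paths compute the same quantities) can be summarized as follows, where $\delta$ is the cont indicator for the stream, $x_i, x_f, x_o, x_g$ are the four hidden_dim_-sized slices of the X bottom, and $E$ is the loss:

\begin{aligned}
i &= \sigma(x_i), \quad f = \delta\,\sigma(x_f), \quad o = \sigma(x_o), \quad g = \tanh(x_g), \\
c_t &= f \cdot c_{t-1} + i \cdot g, \qquad h_t = o \cdot \tanh(c_t),
\end{aligned}

and, with the shorthand $\tilde{c} = \frac{\partial E}{\partial c_t} + \frac{\partial E}{\partial h_t}\, o\, (1 - \tanh^2 c_t)$,

\begin{aligned}
\frac{\partial E}{\partial c_{t-1}} &= \tilde{c}\, f, &
\frac{\partial E}{\partial x_i} &= \tilde{c}\, g\, i (1 - i), &
\frac{\partial E}{\partial x_f} &= \tilde{c}\, c_{t-1}\, f (1 - f), \\
\frac{\partial E}{\partial x_o} &= \frac{\partial E}{\partial h_t}\, \tanh(c_t)\, o (1 - o), &
\frac{\partial E}{\partial x_g} &= \tilde{c}\, i\, (1 - g^2).
\end{aligned}

The one structural difference is that the GPU version factors the gate nonlinearities into the separate LSTMActsForward/LSTMActsBackward kernels, so LSTMUnitBackward produces gradients with respect to the already-activated gates and LSTMActsBackward then applies the $a(1-a)$ and $(1-a^2)$ factors.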
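Below is a minimal sketch of exercising the new layer directly from C++, mirroring the conventions of the tests above. The input feature size (8), hidden size (16), and filler settings are illustrative choices, not values taken from the patch, and the snippet assumes it is built and linked against a Caffe tree that contains these layers:

  #include <vector>

  #include "caffe/blob.hpp"
  #include "caffe/filler.hpp"
  #include "caffe/layers/lstm_layer.hpp"

  using namespace caffe;

  int main() {
    const int T = 4;  // timesteps
    const int N = 2;  // independent streams

    // Time-varying input x: T x N x 8 (feature size 8 is arbitrary here).
    std::vector<int> x_shape(3);
    x_shape[0] = T; x_shape[1] = N; x_shape[2] = 8;
    Blob<float> x(x_shape);

    // Sequence continuation indicators: T x N, 0 at t == 0, 1 afterwards.
    std::vector<int> cont_shape(2);
    cont_shape[0] = T; cont_shape[1] = N;
    Blob<float> cont(cont_shape);
    for (int t = 0; t < T; ++t) {
      for (int n = 0; n < N; ++n) {
        cont.mutable_cpu_data()[t * N + n] = (t > 0);
      }
    }

    // Random input, as in the tests.
    FillerParameter filler_param;
    filler_param.set_min(-1);
    filler_param.set_max(1);
    UniformFiller<float> filler(filler_param);
    filler.Fill(&x);

    // Configure the layer: 16 hidden units, Gaussian weight init.
    LayerParameter param;
    param.set_phase(TEST);
    RecurrentParameter* rp = param.mutable_recurrent_param();
    rp->set_num_output(16);
    rp->mutable_weight_filler()->set_type("gaussian");
    rp->mutable_weight_filler()->set_std(0.1);
    rp->mutable_bias_filler()->set_type("constant");

    LSTMLayer<float> lstm(param);
    Blob<float> y;
    std::vector<Blob<float>*> bottom;
    bottom.push_back(&x);
    bottom.push_back(&cont);
    std::vector<Blob<float>*> top;
    top.push_back(&y);
    lstm.SetUp(bottom, top);     // reshapes y to T x N x 16
    lstm.Forward(bottom, top);   // y now holds h_1 ... h_T for each stream
    return 0;
  }

This is essentially what LSTMLayerTest::TestForward does, minus the single-timestep consistency checks; in a deployed network the same configuration would normally come from a NetParameter rather than being assembled by hand.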