Merge pull request BVLC#4 from justinjfu/justinf_rnn_nonlin
Add nonlinearity option to RNNLayer
justinjfu committed Apr 22, 2015
2 parents 56eb50a + 5f6dc75 commit eb2eb1b
Showing 2 changed files with 16 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/caffe/layers/rnn_layer.cpp
@@ -57,8 +57,11 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
   sum_param.mutable_eltwise_param()->set_operation(
       EltwiseParameter_EltwiseOp_SUM);
 
-  LayerParameter tanh_param;
-  tanh_param.set_type("TanH");
+  LayerParameter output_nonlinearity_param;
+  output_nonlinearity_param.set_type(this->layer_param_.rnn_param().output_nonlinearity());
+
+  LayerParameter recurrent_nonlinearity_param;
+  recurrent_nonlinearity_param.set_type(this->layer_param_.rnn_param().recurrent_nonlinearity());
 
   LayerParameter slice_param;
   slice_param.set_type("Slice");
@@ -176,7 +179,7 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
     }
     {
       LayerParameter* h_neuron_param = net_param->add_layer();
-      h_neuron_param->CopyFrom(tanh_param);
+      h_neuron_param->CopyFrom(recurrent_nonlinearity_param);
       h_neuron_param->set_name("h_neuron_" + ts);
       h_neuron_param->add_bottom("h_neuron_input_" + ts);
       h_neuron_param->add_top("h_" + ts);
@@ -200,7 +203,7 @@ void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
     //        = \tanh( W_ho_h_t )
     {
       LayerParameter* o_neuron_param = net_param->add_layer();
-      o_neuron_param->CopyFrom(tanh_param);
+      o_neuron_param->CopyFrom(output_nonlinearity_param);
       o_neuron_param->set_name("o_neuron_" + ts);
       o_neuron_param->add_bottom("W_ho_h_" + ts);
       o_neuron_param->add_top("o_" + ts);
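Note: with this change, FillUnrolledNet stamps the configured type string into each per-timestep neuron layer instead of a hard-coded "TanH". A minimal sketch of one generated hidden-state layer, in prototxt form, assuming recurrent_nonlinearity is set to "ReLU" and the timestep suffix ts is "1" (names follow the "h_neuron_" + ts pattern in the diff above):

    layer {
      name: "h_neuron_1"
      type: "ReLU"                 # was always "TanH" before this change
      bottom: "h_neuron_input_1"
      top: "h_1"
    }

Since set_type() receives the string verbatim, in principle any registered elementwise Caffe layer type (ReLU, Sigmoid, TanH, ...) can be plugged in here.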
9 changes: 9 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -342,6 +342,7 @@ message LayerParameter {
   optional PowerParameter power_param = 122;
   optional PythonParameter python_param = 130;
   optional RecurrentParameter recurrent_param = 133;
+  optional RNNParameter rnn_param = 136;
   optional ReLUParameter relu_param = 123;
   optional ReshapeParameter reshape_param = 132;
   optional SigmoidParameter sigmoid_param = 124;
@@ -721,6 +722,14 @@ message RecurrentParameter {
   optional bool debug_info = 4 [default = false];
 }
 
+// Message that stores parameters used by RNN Layer
+message RNNParameter {
+  // Nonlinearity on output (o_t)
+  optional string output_nonlinearity = 1 [default = "TanH"];
+  // Nonlinearity on recurrent state (h_t)
+  optional string recurrent_nonlinearity = 2 [default = "TanH"];
+}
+
 // Message that stores parameters used by ReLULayer
 message ReLUParameter {
   // Allow non-zero slope for negative inputs to speed up optimization
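Note: together, the two changes let a net definition select the nonlinearities per RNN layer. A hypothetical usage snippet (layer and blob names here are illustrative, and num_output comes from the existing RecurrentParameter message):

    layer {
      name: "rnn1"
      type: "RNN"
      bottom: "data"
      bottom: "cont"
      top: "rnn1"
      recurrent_param { num_output: 64 }
      rnn_param {
        output_nonlinearity: "ReLU"      # applied to o_t
        recurrent_nonlinearity: "TanH"   # applied to h_t (the default)
      }
    }

Omitting rnn_param keeps both fields at their "TanH" defaults, so existing nets unroll exactly as before.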
