Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add maxout layer, including interface and unittest #229

Merged
merged 8 commits into from
Oct 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/ui/api/trainer_config_helpers/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ img_pool_layer
:members: img_pool_layer
:noindex:

maxout_layer
------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: maxout_layer
:noindex:

Norm Layer
==========

Expand Down
32 changes: 31 additions & 1 deletion paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ extern void hl_avgpool_forward(
* @brief Maximum pool backward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] outGrad input data.
* @param[in] outGrad output grad data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
Expand Down Expand Up @@ -240,4 +240,34 @@ extern void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta);

/**
* @brief MaxOut forward.
*
* @param[in] inData input data.
* @param[out] outData output data.
* @param[out] idData output maxId.
* @param[in] batchSize batchSize.
* @param[in] size number of channels * image height * image width.
* @param[in] featLen feature length = image height * image width.
* @param[in] groups number of groups.
*/
extern void hl_maxout_forward(
const real* inData, real* outData, int* idData,
size_t batchSize, size_t size, size_t featLen, size_t groups);

/**
* @brief MaxOut backward.
*
* @param[out] inGrad input grad data.
* @param[in] outGrad output grad data.
* @param[in] idData output maxId.
* @param[in] batchSize batchSize.
* @param[in] size number of channels * image height * image width.
* @param[in] featLen feature length = image height * image width.
* @param[in] groups number of groups.
*/
extern void hl_maxout_backward(
real* inGrad, const real* outGrad, const int* idData,
size_t batchSize, size_t size, size_t featLen, size_t groups);

#endif /* HL_CNN_H_ */
8 changes: 8 additions & 0 deletions paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,12 @@ inline void hl_CMRNorm_backward(
size_t channels, size_t height, size_t width, size_t sizeX,
real alpha, real beta) {}

/// CPU-only stub for hl_maxout_forward; does nothing when CUDA is unavailable.
/// Parameter renamed `group` -> `groups` to match the declaration in hl_cnn.h.
inline void hl_maxout_forward(
    const real* inData, real* outData, int* idData,
    size_t batchSize, size_t size, size_t featLen, size_t groups) {}

/// CPU-only stub for hl_maxout_backward; does nothing when CUDA is unavailable.
/// Parameter renamed `group` -> `groups` to match the declaration in hl_cnn.h.
inline void hl_maxout_backward(
    real* inGrad, const real* outGrad, const int* idData,
    size_t batchSize, size_t size, size_t featLen, size_t groups) {}

#endif // HL_CNN_STUB_H_
59 changes: 59 additions & 0 deletions paddle/cuda/src/hl_cuda_cnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,62 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward");
}

/**
 * Maxout forward kernel: one thread per OUTPUT element.
 *
 * Layout assumption (from the host wrapper): input is
 * [batch][outputChannels * groups][featLen], output is
 * [batch][outputChannels][featLen]; consecutive members of a group are
 * featLen apart in the input.
 *
 * Fix: use a size_t flat index instead of int — the original mixed a
 * signed `int index` with the unsigned `nthreads` comparison and could
 * overflow for launches larger than 2^31 elements.
 */
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
                                real * outData, int* idData,
                                size_t size, size_t featLen, size_t groups) {
  size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    size_t batch_idx = index / size;     // sample within the batch
    size_t i = index % size;             // offset inside this sample's output
    size_t channel_idx = i / featLen;    // output channel
    size_t feat_idx = i % featLen;       // pixel within the feature map
    // First input element of the group feeding this output element.
    size_t data_idx =
        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
    real max = inData[data_idx];
    int maxId = 0;
    // Scan the remaining group members (stride featLen) for the maximum.
    for (size_t g = 1; g < groups; ++g) {
      real tmp = inData[data_idx + g * featLen];
      if (tmp > max) {
        max = tmp;
        maxId = (int)g;
      }
    }
    outData[index] = max;
    idData[index] = maxId;  // argmax, consumed by the backward pass
  }
}

/**
 * Launch the maxout forward kernel with one thread per output element.
 * size = outputChannels * featLen; idData receives the per-element argmax.
 */
void hl_maxout_forward(const real* inData, real* outData,
                       int* idData, size_t batchSize, size_t size,
                       size_t featLen, size_t groups) {
  const int threadsPerBlock = 1024;
  int totalThreads = size * batchSize;
  // Ceil-divide so the last partial block is still launched.
  int gridSize = (totalThreads + threadsPerBlock - 1) / threadsPerBlock;
  maxoutFpCompute<<<gridSize, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalThreads, inData, outData, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_forward failed");
}

/**
 * Maxout backward kernel: one thread per OUTPUT-gradient element; routes
 * each output gradient back to the single input element that won the max
 * (recorded in idData by the forward pass).
 *
 * No atomics are needed: distinct output elements map to distinct input
 * slots, since (channel_idx, maxId, feat_idx) is unique per thread.
 *
 * Fix: use a size_t flat index instead of int — the original mixed a
 * signed `int index` with the unsigned `nthreads` comparison and could
 * overflow for very large launches.
 */
__global__ void maxoutBpCompute(size_t nthreads, real* inGrad,
                                const real* outGrad, const int* idData,
                                size_t size, size_t featLen, size_t groups) {
  size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    size_t batch_idx = index / size;     // sample within the batch
    size_t i = index % size;             // offset inside this sample's output
    size_t channel_idx = i / featLen;    // output channel
    size_t feat_idx = i % featLen;       // pixel within the feature map
    size_t newIndex = batch_idx * size;  // start of this sample's output grad
    // Input slot that produced the max for this output element.
    size_t gradIdx =
        (channel_idx * groups + (size_t)(idData + newIndex)[i]) * featLen +
        feat_idx;
    // Accumulate (+=): the input gradient buffer may already hold
    // contributions from other layers sharing this input.
    (inGrad + newIndex * groups)[gradIdx] += (outGrad + newIndex)[i];
  }
}

/**
 * Launch the maxout backward kernel with one thread per output-gradient
 * element; idData is the argmax recorded by hl_maxout_forward.
 */
void hl_maxout_backward(real* inGrad, const real* outGrad,
                        const int* idData, size_t batchSize, size_t size,
                        size_t featLen, size_t groups) {
  const int threadsPerBlock = 1024;
  int totalThreads = size * batchSize;
  // Ceil-divide so the last partial block is still launched.
  int gridSize = (totalThreads + threadsPerBlock - 1) / threadsPerBlock;
  maxoutBpCompute<<<gridSize, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalThreads, inGrad, outGrad, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_backward failed");
}
87 changes: 87 additions & 0 deletions paddle/gserver/layers/MaxOutLayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MaxOutLayer.h"
#include "hl_gpu.h"
#include "hl_cnn.h"

namespace paddle {

REGISTER_LAYER(maxout, MaxOutLayer);

size_t MaxOutLayer::getSize() {
  const MaxOutConfig& maxoutConf = config_.inputs(0).maxout_conf();

  // Prefer the input's runtime frame size; fall back to the static config
  // when the input layer did not set one (frame size of 0).
  imgSizeH_ = inputLayers_[0]->getOutput().getFrameHeight();
  imgSizeW_ = inputLayers_[0]->getOutput().getFrameWidth();
  if (imgSizeH_ == 0) {
    imgSizeH_ = maxoutConf.img_size_y();
  }
  if (imgSizeW_ == 0) {
    imgSizeW_ = maxoutConf.img_size_x();
  }

  featLen_ = imgSizeH_ * imgSizeW_;

  // Maxout keeps the spatial dimensions; only the channel count shrinks.
  getOutput().setFrameHeight(imgSizeH_);
  getOutput().setFrameWidth(imgSizeW_);

  return featLen_ * outputChannels_;
}

bool MaxOutLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);

/* the size of inputs for maxout-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1UL);

const MaxOutConfig& conf = config_.inputs(0).maxout_conf();
groups_ = conf.groups();
channels_ = conf.channels();
CHECK_EQ(channels_ % groups_, 0UL);
outputChannels_ = channels_ / groups_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check CHECK_EQ(channels%groups_, 0);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add check


return true;
}

void MaxOutLayer::forward(PassType passType) {
  Layer::forward(passType);

  /* malloc memory for the output_ if necessary */
  /* note: one sample correspond to one column */
  size_t batchSize = getInput(0).getBatchSize();
  // getSize() also refreshes imgSizeH_/imgSizeW_/featLen_ and the output
  // frame size, so it must run before the maxout kernel is dispatched.
  size_t size = getSize();
  resetOutput(batchSize, size);
  MatrixPtr inputV = getInputValue(0);
  MatrixPtr outV = getOutputValue();

  // maxoutId_ records, per output element, which group member won the max;
  // backward() reads it to route gradients. One entry per output element.
  IVector::resizeOrCreate(maxoutId_, size * batchSize, useGpu_);
  outV->maxoutForward(*inputV, *maxoutId_, outputChannels_, groups_);
}

void MaxOutLayer::backward(const UpdateCallback& callback) {
(void)callback;

/* Do derivation */
MatrixPtr inputG = getInputGrad(0);
MatrixPtr outG = getOutputGrad();

if (inputG) {
inputG->maxoutBackward(*outG, *maxoutId_, outputChannels_, groups_);
}
}

} // namespace paddle
54 changes: 54 additions & 0 deletions paddle/gserver/layers/MaxOutLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"

namespace paddle {

/**
* A layer to do max out on conv layer output.
* Input: output of a conv layer.
* Output: feature map size same as input. Channel is (input channel) / groups.
* So the number of input channels must be divisible by groups.
*
* The config file api is maxout_layer.
*/

class MaxOutLayer : public Layer {
protected:
  /// number of input channels folded into each output channel
  size_t groups_;
  /// input/output feature-map height and width (spatial size is unchanged)
  size_t imgSizeH_, imgSizeW_;
  /// outputChannels_ = channels_ / groups_
  size_t channels_, outputChannels_;
  /// feature length = imgSizeH_ * imgSizeW_
  size_t featLen_;
  /// per-output-element argmax recorded by forward(), consumed by backward()
  IVectorPtr maxoutId_;

public:
  /// return imgSizeH_ * imgSizeW_ * outputChannels_;
  /// also refreshes imgSizeH_/imgSizeW_/featLen_ from the input frame size
  size_t getSize();

  explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {}
  virtual ~MaxOutLayer() {}

  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);

  void forward(PassType passType);
  void backward(const UpdateCallback& callback = nullptr);
};

} // namespace paddle
29 changes: 27 additions & 2 deletions paddle/gserver/tests/rnn_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@

from paddle.trainer.PyDataProvider2 import *

# Note that each config should have an independent provider
# in current design of PyDataProvider2.
#######################################################
# Two samples; each is (list of sub-sequences of token ids, integer label).
data = [
    [[[1, 3, 2], [4, 5, 2]], 0],
    [[[0, 2], [2, 5], [0, 1, 2]], 1],
]


# Used for sequence_nest_rnn.conf
@provider(input_types=[integer_value_sub_sequence(10),
                       integer_value(3)],
          should_shuffle=False)
def process_subseq(settings, file_name):
    # Emit each (sub-sequence list, label) sample unchanged.
    for sample in data:
        yield sample


# Used for sequence_rnn.conf
@provider(input_types=[integer_value_sequence(10),
integer_value(3)],
should_shuffle=False)
Expand All @@ -38,11 +41,32 @@ def process_seq(settings, file_name):
seq += subseq
yield seq, d[1]

# Used for sequence_nest_rnn_multi_input.conf
@provider(input_types=[integer_value_sub_sequence(10),
                       integer_value(3)],
          should_shuffle=False)
def process_subseq2(settings, file_name):
    # Emit each (sub-sequence list, label) sample unchanged.
    for sample in data:
        yield sample

# Used for sequence_rnn_multi_input.conf
@provider(input_types=[integer_value_sequence(10),
                       integer_value(3)],
          should_shuffle=False)
def process_seq2(settings, file_name):
    # Flatten each sample's sub-sequences into one plain sequence.
    for sub_seqs, label in data:
        flat = [token for sub in sub_seqs for token in sub]
        yield flat, label

###########################################################
# Two samples; each is (sub-sequences for input 1, sub-sequences for
# input 2, integer label). The two inputs have unequal lengths on purpose.
data2 = [
    [[[1, 2], [4, 5, 2]], [[5, 4, 1], [3, 1]], 0],
    [[[0, 2], [2, 5], [0, 1, 2]], [[1, 5], [4], [2, 3, 6, 1]], 1],
]

# Used for sequence_nest_rnn_multi_unequalength_inputs.conf
@provider(input_types=[integer_value_sub_sequence(10),
integer_value_sub_sequence(10),
integer_value(2)],
Expand All @@ -52,6 +76,7 @@ def process_unequalength_subseq(settings, file_name):
yield d


# Used for sequence_rnn_multi_unequalength_inputs.conf
@provider(input_types=[integer_value_sequence(10),
integer_value_sequence(10),
integer_value(2)],
Expand Down
2 changes: 1 addition & 1 deletion paddle/gserver/tests/sequence_nest_rnn_multi_input.conf
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import *
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_subseq')
obj='process_subseq2')


settings(batch_size=2, learning_rate=0.01)
Expand Down
2 changes: 1 addition & 1 deletion paddle/gserver/tests/sequence_rnn_multi_input.conf
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import *
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_seq')
obj='process_seq2')


settings(batch_size=2, learning_rate=0.01)
Expand Down
18 changes: 18 additions & 0 deletions paddle/gserver/tests/test_LayerGrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,24 @@ TEST(Layer, blockExpandLayer) {
}
}

// Gradient check for the maxout layer on both CPU and GPU.
TEST(Layer, maxoutLayer) {
  TestConfig config;
  config.biasSize = 0;  // maxout has no bias parameter
  config.layerConfig.set_type("maxout");

  // Input size 4096 = channels (4) * img_size_y (32) * img_size_x (32).
  config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0});
  LayerInputConfig* input = config.layerConfig.add_inputs();
  MaxOutConfig* maxout = input->mutable_maxout_conf();

  maxout->set_img_size_x(32);
  maxout->set_img_size_y(32);
  maxout->set_channels(4);
  // 4 channels / 2 groups -> 2 output channels.
  maxout->set_groups(2);

  for (auto useGpu : {false, true}) {
    testLayerGrad(config, "maxout", 10, false, useGpu);
  }
}
void testFcLayer(string format, size_t nnz) {
TestConfig config;
config.biasSize = 4096;
Expand Down
Loading