From 1e1a33b574c2dde11c8c112f550b39207673a07e Mon Sep 17 00:00:00 2001
From: Haonan
Date: Tue, 6 Sep 2016 11:55:02 -0700
Subject: [PATCH 01/10] Argument concat for subsequence start positions

Change-Id: Ia60c008a8c922f66e6b5e2ca3e488fc4625d6506
---
 paddle/parameter/Argument.cpp | 46 ++++++++++++++++++++++++++---------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 93f86ceccffc5..8610a66452358 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -269,6 +269,9 @@ void Argument::concat(const std::vector<Argument>& args,
                       const std::vector<int>& selectRows,
                       const std::vector<int>& seqStartPos, bool useGpu,
                       hl_stream_t stream, PassType passType) {
+  CHECK(!subSequenceStartPositions)
+      << "undefined behavior for subsequence positions";
+
   size_t batchSize = selectRows.size();
   auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
                                      int startRow, int pos, int size,
@@ -347,9 +350,11 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
                       hl_stream_t stream, PassType passType) {
   int32_t batchSize = 0;
   int64_t numSequences = 0;
+  int64_t numSubSequences = 0;
   for (auto& arg : args) {
     batchSize += arg.getBatchSize();
     numSequences += arg.getNumSequences();
+    numSubSequences += arg.getNumSubSequences();
   }

   auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
@@ -393,8 +398,26 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
     std::copy(src->begin(), src->end(), dst->begin() + startRow);
   };

+  auto copySequencePos = []
+      (ICpuGpuVectorPtr& dstSeq, const ICpuGpuVectorPtr& srcSeq,
+       int dstNumSequences, int srcNumSequences,
+       int& startSequences, int startRow) {
+    if (srcSeq) {
+      ICpuGpuVector::resizeOrCreate(dstSeq, dstNumSequences + 1, false);
+      const int* src = srcSeq->getData(false);
+      int* dest = dstSeq->getMutableData(false);
+      for (int i = 0; i < srcNumSequences + 1; ++i) {
+        dest[i + startSequences] = src[i] + startRow;
+      }
+      startSequences += srcNumSequences;
+    } else {
+      dstSeq.reset();
+    }
+  };
+
   int startRow = 0;
   int startSequences = 0;
+  int startSubSequences = 0;
   dataId = args[0].dataId;
   for (auto& arg : args) {
     CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have"
@@ -403,17 +426,18 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
     copyArg(value, arg.value, startRow, useGpu);
     if (passType != PASS_TEST) copyArg(grad, arg.grad, startRow, useGpu);
     copyIds(ids, arg.ids, startRow, useGpu);
-    if (arg.sequenceStartPositions) {
-      ICpuGpuVector::resizeOrCreate(sequenceStartPositions,
-                                    numSequences + 1,
-                                    false);
-      const int* src = arg.sequenceStartPositions->getData(false);
-      int* dest = sequenceStartPositions->getMutableData(false);
-      for (int i = 0; i < arg.getNumSequences() + 1; ++i) {
-        dest[i + startSequences] = src[i] + startRow;
-      }
-      startSequences += arg.getNumSequences();
-    }
+    copySequencePos(sequenceStartPositions,
+                    arg.sequenceStartPositions,
+                    numSequences,
+                    arg.getNumSequences(),
+                    startSequences,
+                    startRow);
+    copySequencePos(subSequenceStartPositions,
+                    arg.subSequenceStartPositions,
+                    numSubSequences,
+                    arg.getNumSubSequences(),
+                    startSubSequences,
+                    startRow);
     copyStrs(strs, arg.strs, startRow, useGpu);
     startRow += arg.getBatchSize();
   }

From 6a873f505f77b3fb183f791358413c27afa5b922 Mon Sep 17 00:00:00 2001
From: Haonan
Date: Tue, 6 Sep 2016 11:44:57 -0700
Subject: [PATCH 02/10] fix the layer name error when stacking RecurrentGroups

* the last layer in the stack already has all the suffixes

---
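[Editor's note — PATCH 01/10 above] The new copySequencePos helper encodes one rule: when several Arguments are concatenated, each argument's start-position array (N sequences give N + 1 offsets, the last equal to the row count) is appended with every offset shifted by the rows already copied, so the merged array still describes one contiguous batch. Below is a minimal standalone sketch of that rule — plain std::vector<int> instead of ICpuGpuVector, and appendStartPositions is a hypothetical name, not PaddlePaddle API.

```cpp
#include <cassert>
#include <vector>

// Append one argument's start positions into the concatenated array, shifting
// every offset by the rows consumed so far; the overlapping terminal entry is
// overwritten by the next call, exactly as copySequencePos does above.
void appendStartPositions(std::vector<int>& dst, const std::vector<int>& src,
                          int& startSequences, int startRow) {
  for (size_t i = 0; i < src.size(); ++i) {
    dst[startSequences + i] = src[i] + startRow;
  }
  startSequences += static_cast<int>(src.size()) - 1;
}

int main() {
  // Two arguments: 2 sequences over 5 rows, then 1 sequence over 3 rows.
  std::vector<int> a = {0, 2, 5};
  std::vector<int> b = {0, 3};
  std::vector<int> dst(3 + 1);  // total numSequences + 1
  int startSequences = 0;
  appendStartPositions(dst, a, startSequences, /*startRow=*/0);
  appendStartPositions(dst, b, startSequences, /*startRow=*/5);
  assert((dst == std::vector<int>{0, 2, 5, 8}));
  return 0;
}
```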
 python/paddle/trainer/config_parser.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 53d8bb98f09e8..b26a63e7f3c1d 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -262,8 +262,8 @@ def SubModelEnd(name = None):

 def MakeLayerNameInParentSubmodel(name):
     suffix = ""
-    for submodel in g_submodel_stack[1:]:
-        suffix = "@" + submodel.name + suffix
+    if len(g_submodel_stack) > 1:
+        suffix = "@" + g_submodel_stack[-1].name
     return name + suffix

 def GetLayerBaseName(name):

From fbfd24e6d949d60fd415dfd6bb393ca02eb27bcc Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 7 Sep 2016 11:53:40 +0800
Subject: [PATCH 03/10] revert CRFLayer, remove wrong gpu support

Change-Id: I636cf13af5becb1168bc9749266b55580c46f6c9
---
 paddle/gserver/layers/CRFLayer.cpp      | 81 +++++--------------------
 paddle/gserver/layers/CRFLayer.h        |  5 --
 paddle/gserver/tests/test_LayerGrad.cpp |  7 +--
 3 files changed, 17 insertions(+), 76 deletions(-)

diff --git a/paddle/gserver/layers/CRFLayer.cpp b/paddle/gserver/layers/CRFLayer.cpp
index df8a2b03142b8..fb0a0ddb3d45b 100644
--- a/paddle/gserver/layers/CRFLayer.cpp
+++ b/paddle/gserver/layers/CRFLayer.cpp
@@ -47,81 +47,40 @@ bool CRFLayer::init(const LayerMap& layerMap,
   // We don't need sequenceStartPositions because each sample of output_ is
   // for the cost of one sequence.
   setNeedSequenceInfo(false);
-  if (useGpu_) {
-    tmpCpuInput_.reserve(inputLayers_.size());
-    for (size_t i = 0; i < inputLayers_.size(); i++) {
-      tmpCpuInput_.push_back(Argument());
-    }
-  }
+
   return true;
 }

 void CRFLayer::forward(PassType passType) {
   Layer::forward(passType);
-  if (useGpu_) {
-    for (size_t i = 0; i < inputLayers_.size(); i++) {
-      tmpCpuInput_[i].resizeAndCopyFrom(getInput(i), false, HPPL_STREAM_1);
-    }
-    VectorPtr cpuParameterValue;
-    VectorPtr cpuParameterGradient;
-    cpuParameterValue =
-        Vector::create(parameter_->getBuf(PARAMETER_VALUE)->getSize(), false);
-    cpuParameterValue->
-        copyFrom(*parameter_->getBuf(PARAMETER_VALUE), HPPL_STREAM_1);
-    if (parameter_->getBuf(PARAMETER_GRADIENT)) {
-      cpuParameterGradient =
-          Vector::create(parameter_->getBuf(PARAMETER_GRADIENT)->getSize(),
-                         false);
-      cpuParameterGradient->
-          copyFrom(*parameter_->getBuf(PARAMETER_GRADIENT), HPPL_STREAM_1);
-    } else {
-      cpuParameterGradient = nullptr;
-    }
-    forwardImp(tmpCpuInput_[0], tmpCpuInput_[1], cpuParameterValue,
-               cpuParameterGradient);
-    parameter_->getBuf(PARAMETER_VALUE)->copyFrom(*cpuParameterValue,
-                                                  HPPL_STREAM_1);
-    if (parameter_->getBuf(PARAMETER_GRADIENT)) {
-      parameter_->getBuf(PARAMETER_GRADIENT)->copyFrom(*cpuParameterGradient,
-                                                       HPPL_STREAM_1);
-    }
-  } else {
-    forwardImp(getInput(0), getInput(1), parameter_->getBuf(PARAMETER_VALUE),
-               parameter_->getBuf(PARAMETER_GRADIENT));
-  }
-}

-void CRFLayer::forwardImp(const Argument&output,
-                          const Argument& label,
-                          VectorPtr parameterValue,
-                          VectorPtr parameterGradient) {
+  CHECK(!useGpu_) << "GPU is not supported";
+
+  const Argument& output = getInput(0);
+  const Argument& label = getInput(1);
   CHECK(label.sequenceStartPositions);
   CHECK(label.ids);

   int batchSize = output.getBatchSize();
   size_t numSequences = label.sequenceStartPositions->getSize() - 1;
   resizeOutput(numSequences, 1);
-  std::vector<real> out(numSequences);

   const int* starts = label.sequenceStartPositions->getData(false);
   CHECK_EQ(starts[numSequences], batchSize);
-  VectorPtr cpuParameterValue;
-  VectorPtr cpuParameterGradient;

   for (size_t i = 0; i < numSequences; ++i) {
     if (i >= crfs_.size()) {
       crfs_.emplace_back(numClasses_,
-                         parameterValue->getData(),
-                         parameterGradient
-                             ? parameterGradient->getData()
+                         parameter_->getBuf(PARAMETER_VALUE)->getData(),
+                         parameter_->getBuf(PARAMETER_GRADIENT)
+                             ? parameter_->getBuf(PARAMETER_GRADIENT)->getData()
                              : nullptr);
     }
-    out[i] = crfs_[i].forward(
+    output_.value->getData()[i] = crfs_[i].forward(
         output.value->getData() + numClasses_ * starts[i],
         label.ids->getData() + starts[i],
         starts[i + 1] - starts[i]);
   }
-  output_.value->copyFrom(out.data(), numSequences);
+
   if (weightLayer_) {
     const MatrixPtr& weight = getInputValue(*weightLayer_);
     getOutputValue()->dotMul(*getOutputValue(), *weight);
@@ -129,22 +88,8 @@ void CRFLayer::forwardImp(const Argument&output,
 }

 void CRFLayer::backward(const UpdateCallback &callback) {
-  (void)callback;
-  if (useGpu_) {
-    backwardImp(callback, tmpCpuInput_[0], tmpCpuInput_[1]);
-    const_cast<Argument&>(getInput(0)).
-        resizeAndCopyFrom(tmpCpuInput_[0], true, HPPL_STREAM_1);
-    const_cast<Argument&>(getInput(1)).
-        resizeAndCopyFrom(tmpCpuInput_[1], true, HPPL_STREAM_1);
-
-  } else {
-    backwardImp(callback, getInput(0), getInput(1));
-  }
-}
-
-void CRFLayer::backwardImp(const UpdateCallback& callback,
-                           const Argument&output,
-                           const Argument& label) {
+  const Argument& output = getInput(0);
+  const Argument& label = getInput(1);
   const int* starts = label.sequenceStartPositions->getData(false);
   int numSequences = label.sequenceStartPositions->getSize() - 1;

@@ -159,9 +104,11 @@ void CRFLayer::backwardImp(const UpdateCallback& callback,
       grad->mulScalar(weight);
     }
   }
+
   if (coeff_ != real(1.0f)) {
     output.grad->mulScalar(coeff_);
   }
+
   parameter_->incUpdate(callback);
 }

diff --git a/paddle/gserver/layers/CRFLayer.h b/paddle/gserver/layers/CRFLayer.h
index 5facb9b54818c..c6ba8e7c965a3 100644
--- a/paddle/gserver/layers/CRFLayer.h
+++ b/paddle/gserver/layers/CRFLayer.h
@@ -32,11 +32,7 @@ class CRFLayer : public Layer {
   explicit CRFLayer(const LayerConfig& config) : Layer(config) {}
   virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
   virtual void forward(PassType passType);
-  void forwardImp(const Argument&output, const Argument& label,
-                  VectorPtr parameterValue, VectorPtr parameterGradient);
   virtual void backward(const UpdateCallback& callback);
-  void backwardImp(const UpdateCallback& callback, const Argument&output,
-                   const Argument& label);

 protected:
   size_t numClasses_;
@@ -44,7 +40,6 @@ class CRFLayer : public Layer {
   std::vector<LinearChainCRF> crfs_;
   LayerPtr weightLayer_;  // weight for each sequence
   real coeff_;  // weight for the layer
-  std::vector<Argument> tmpCpuInput_;
 };

 }  // namespace paddle

diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 7bb79ff5b702a..5c80eb546cfaf 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -179,10 +179,9 @@ TEST(Layer, CRFLayer) {
   config.layerConfig.add_inputs();
   config.layerConfig.add_inputs();

-  for (auto useGpu : {false, true}) {
-    testLayerGrad(config, "crf", 100, /* trans */ false, /* useGpu */ useGpu,
-                  false /*useWeight*/, 0.03 /*epsilon*/);
-  }
+  // GPU is not supported for now
+  testLayerGrad(config, "crf", 100, /* trans */ false, /* useGpu */ false,
+                false /*useWeight*/, 0.03 /*epsilon*/);
 }

 TEST(Layer, CTCLayer) {

From 721b09eee6f3870cfe7a9ec9b59ef6ba8d42e264 Mon Sep 17 00:00:00 2001
From: liuyuan04
Date: Wed, 7 Sep 2016 13:34:50 +0800
Subject: [PATCH 04/10] Update Jumbo package to 0.8.0b0.
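[Editor's note — PATCH 03/10 above] With the GPU path reverted, the core the layer keeps is the per-sequence slicing that both forward() and backward() rely on: row range [starts[i], starts[i+1]) of the flat batch feeds the i-th CRF. A self-contained sketch of that slicing, with a stand-in cost function instead of Paddle's LinearChainCRF and numClasses_ folded to 1 (the real layer advances the emission pointer by numClasses_ * starts[i]):

```cpp
#include <cstdio>
#include <vector>

// Stand-in for LinearChainCRF::forward(): any per-sequence cost will do here.
struct FakeCRF {
  double forward(const double* emit, const int* labels, int len) {
    double cost = 0.0;
    for (int t = 0; t < len; ++t) cost += emit[t] * (labels[t] + 1);
    return cost;
  }
};

int main() {
  // A batch of 5 rows split into sequences [0, 2) and [2, 5).
  std::vector<int> starts = {0, 2, 5};
  std::vector<double> emissions = {0.1, 0.2, 0.3, 0.4, 0.5};
  std::vector<int> labels = {0, 1, 0, 1, 0};
  FakeCRF crf;
  for (size_t i = 0; i + 1 < starts.size(); ++i) {
    double cost = crf.forward(emissions.data() + starts[i],
                              labels.data() + starts[i],
                              starts[i + 1] - starts[i]);
    std::printf("sequence %zu: cost %.2f\n", i, cost);
  }
  return 0;
}
```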
Change-Id: I0b8608feab8f6be5094e8981fc5f65cb401ed415
---
 CMakeLists.txt                   | 2 +-
 paddle/CMakeLists.txt            | 3 +++
 paddle/{setup.py => setup.py.in} | 4 ++--
 3 files changed, 6 insertions(+), 3 deletions(-)
 rename paddle/{setup.py => setup.py.in} (93%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0d950e144ffdc..007f1f18bb655 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
 project(paddle CXX C)
 set(PADDLE_MAJOR_VERSION 0)
 set(PADDLE_MINOR_VERSION 8)
-set(PADDLE_PATCH_VERSION 0b)
+set(PADDLE_PATCH_VERSION 0b0)
 set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})

 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")

diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index c6fa7dc2b16e1..cae0f64400a7e 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -7,6 +7,9 @@ add_subdirectory(pserver)
 add_subdirectory(trainer)
 add_subdirectory(scripts)

+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
+    ${CMAKE_CURRENT_SOURCE_DIR}/setup.py)
+
 if(WITH_PREDICT_SDK)
   add_subdirectory(predict)
 endif()

diff --git a/paddle/setup.py b/paddle/setup.py.in
similarity index 93%
rename from paddle/setup.py
rename to paddle/setup.py.in
index fabe2a6b4c1d5..da86eb795dc58 100644
--- a/paddle/setup.py
+++ b/paddle/setup.py.in
@@ -35,11 +35,11 @@ pass

 setup(name="py_paddle",
-      version="0.8.0b", # TODO(yuyang18): Make this version same as CMake
+      version="@PADDLE_VERSION@",
       ext_modules=[
         Extension('py_paddle._swig_paddle',   # Build SWIG Extension.
                   ['Paddle_wrap.cxx'],
-                  extra_link_args=["-Xlinker", '-start-group'] +
+                  extra_link_args=["-Xlinker", '-start-group'] +
                   extra_links + ["-Xlinker", "-end-group"]
                   )
       ],

From 7ad55a4e76f334e5b1f86eb45fff5abb74210de8 Mon Sep 17 00:00:00 2001
From: xuwei06
Date: Tue, 6 Sep 2016 23:08:19 -0700
Subject: [PATCH 05/10] Fix ThreadParameterUpdater

The reference return type causes a segfault at ThreadParameterUpdater.cpp:123
under gcc 5.4.

Change-Id: I7a1c155892722076a7cb48793b83d5ee525747d1
---
 paddle/trainer/ThreadParameterUpdater.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h
index f47d3b08c1677..d8a7a5dd4f12a 100644
--- a/paddle/trainer/ThreadParameterUpdater.h
+++ b/paddle/trainer/ThreadParameterUpdater.h
@@ -79,7 +79,7 @@ class SgdThreadUpdater : public ParameterUpdater {
   // The update function for after update operations, such as averager.
   void threadTraverse(const ParameterOptimizer::TraverseCallback& callback,
                       int tid, size_t numThreads, Parameter* para);
-  typedef std::function<const ParameterOptimizer::TraverseCallback&()>
+  typedef std::function<const ParameterOptimizer::TraverseCallback()>
       GetTraverseCallback;
   void traverse(GetTraverseCallback getTraverseCallback);
 };

From fdd40e5528576899d5dd744eb772a79c53300d59 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 7 Sep 2016 06:58:45 +0000
Subject: [PATCH 06/10] Fix 32-bit gcc compile warnings.
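[Editor's note — PATCH 05/10 above] The failure mode behind that one-character fix is worth spelling out: a std::function whose signature returns a const reference will happily wrap a callable that returns by value; each call then binds the reference to a temporary destroyed inside the call wrapper, and the caller reads dead stack storage — the kind of gcc 5.4 segfault the commit message describes. A minimal reproduction sketch with generic stand-in types (not the Paddle classes):

```cpp
#include <functional>
#include <iostream>
#include <string>

// BadGetter mirrors the pre-fix typedef: the wrapper's return type is a
// const reference, but the lambda returns a std::string by value, so the
// reference dangles as soon as operator() returns. GoodGetter returns by
// value, matching the fixed typedef.
using BadGetter = std::function<const std::string&()>;
using GoodGetter = std::function<std::string()>;

int main() {
  BadGetter bad = [] { return std::string("callback"); };
  GoodGetter good = [] { return std::string("callback"); };
  // const std::string& s = bad();  // undefined behavior: dangling reference
  std::cout << good() << "\n";      // safe: the result is a real copy
  return 0;
}
```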
Change-Id: Ibc39ca1d1a27d0d28569e29f41a5647659f8c764
---
 paddle/pserver/ParameterClient2.cpp | 6 +++---
 paddle/pserver/ParameterServer2.cpp | 2 +-
 paddle/pserver/SocketChannel.cpp    | 3 ++-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp
index 07961cbdcc20c..d0e5352c828d1 100644
--- a/paddle/pserver/ParameterClient2.cpp
+++ b/paddle/pserver/ParameterClient2.cpp
@@ -278,7 +278,7 @@ void ParameterClient2::prepareSendData(
     if (sendingPara) {
       sendJob->parallelInputIovs[serverId].push_back(
-          {sendMat->getLocalRow(row), sizeof(real) * blockSize});
+          {sendMat->getLocalRow(row), sizeof(real) * (size_t) blockSize});
       /// detect sparse parameter distribution
       sparseDistribution_->probeDistribution(serverId,
                                              sizeof(real) * blockSize);
@@ -302,8 +302,8 @@ void ParameterClient2::prepareSendData(
       block->set_begin_pos(beginDim);
       block->set_block_size(endDim - beginDim);
       if (buf) {
-        sendJob->parallelInputIovs[serverId].push_back(
-            {buf + beginDim, sizeof(real) * (endDim - beginDim)});
+        sendJob->parallelInputIovs[serverId].push_back({buf + beginDim,
+            sizeof(real) * ((size_t) (endDim - beginDim))});
       }
     }
   }

diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp
index bb3caeb728d1c..8f72c1988d167 100644
--- a/paddle/pserver/ParameterServer2.cpp
+++ b/paddle/pserver/ParameterServer2.cpp
@@ -724,7 +724,7 @@ void ParameterServer2::sendBackParameter(const ParameterBlock& block,
           << " id=" << block.para_id() << " block id=" << block.block_id();
   real* valueBuffer = vectors_[parameterType]->getPoint(offset);
-  outputBuffers->push_back({valueBuffer, block.block_size()});
+  outputBuffers->push_back({valueBuffer, (size_t) block.block_size()});
 }

 void ParameterServer2::sendBackParameter(const ParameterBlock& block,

diff --git a/paddle/pserver/SocketChannel.cpp b/paddle/pserver/SocketChannel.cpp
index ebb4245b9a7df..698473060a4c1 100644
--- a/paddle/pserver/SocketChannel.cpp
+++ b/paddle/pserver/SocketChannel.cpp
@@ -148,7 +148,8 @@ void SocketChannel::writeMessage(const std::vector<struct iovec>& userIovs) {
   std::vector<struct iovec> iovs;
   iovs.reserve(userIovs.size() + 2);
   iovs.push_back({&header, sizeof(header)});
-  iovs.push_back({&iovLengths[0], sizeof(iovLengths[0]) * header.numIovs});
+  iovs.push_back({&iovLengths[0],
+                  sizeof(iovLengths[0]) * (size_t) header.numIovs});
   iovs.insert(iovs.end(), userIovs.begin(), userIovs.end());

   header.totalLength = 0;

From 903d5c7ec04e4074191875abaf58b418219561f2 Mon Sep 17 00:00:00 2001
From: He
Date: Wed, 7 Sep 2016 16:01:36 +0800
Subject: [PATCH 07/10] bug fix for hl_matrix_classification_error

---
 paddle/cuda/src/hl_cuda_matrix.cu        | 33 ++++++++++--------------
 paddle/math/tests/test_matrixCompare.cpp | 33 +++++++++++++++++++++---
 2 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 15799919fa137..fc003b7d6377d 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -266,25 +266,21 @@ template<int blockSize>
 __global__ void KeMatrixClassificationError(real* in_A,
                                             int* in_B,
                                             real* out_C,
-                                            int dimM,
                                             int dimN) {
   __shared__ real max_s[blockSize];
   __shared__ int max_l[blockSize];
-  int cnt = (dimN + blockSize -1) / blockSize;
-  int tid = threadIdx.x;
-  int lmt = tid;
-  int index = 0;
-  real t;
+  const int tid = threadIdx.x;
+  const int rowId = blockIdx.x;

   max_s[tid] = -1e30f;
-  for (int ii = 0; ii < cnt && lmt < dimN; ii++) {
-    index = blockIdx.y*dimN + lmt;
-    t = in_A[index];
-    if (max_s[tid] < t) {
-      max_s[tid] = t;
-      max_l[tid] = lmt;
+  in_A += rowId * dimN;
+  real tmp;
+  for (int colId = tid; colId < dimN; colId += blockSize) {
+    tmp = in_A[colId];
+    if (max_s[tid] < tmp) {
+      max_s[tid] = tmp;
+      max_l[tid] = colId;
     }
-    lmt += blockSize;
   }
   __syncthreads();

@@ -300,7 +296,7 @@ __global__ void KeMatrixClassificationError(real* in_A,
   __syncthreads();

   if (tid == 0) {
-    out_C[blockIdx.y] = (max_l[0] == in_B[blockIdx.y] ? 0 : 1.0f);
+    out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
   }
 }

@@ -313,12 +309,9 @@ void hl_matrix_classification_error(real* A_d,
   CHECK_NOTNULL(B_d);
   CHECK_NOTNULL(C_d);

-  int blocksX = 1;
-  int blocksY = dimM;
-  dim3 threads(1024, 1);
-  dim3 grid(blocksX, blocksY);
-  KeMatrixClassificationError<1024><<< grid, threads, 0, STREAM_DEFAULT >>>
-      (A_d, B_d, C_d, dimM, dimN);
+  // each sample is calculated by one block
+  KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
+      (A_d, B_d, C_d, dimN);

   CHECK_SYNC("hl_matrix_classification_error");
 }

diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 7caade444b827..fe8eacc2efbc5 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -1697,7 +1697,6 @@ TEST(Matrix, cosSimDerivate) {
   }
 }
-
 void testParamReluForward(int height, int width, int w_height, int w_width) {
   MatrixPtr output = CpuMatrix::create(height, width, false, false);
@@ -1736,7 +1735,6 @@ TEST(Matrix, paramReluForward) {
   }
 }
-
 void testParamReluBackwardW(int height, int width, int w_height, int w_width) {
   MatrixPtr oGrad = CpuMatrix::create(height, width, false, false);
@@ -1775,7 +1773,6 @@ TEST(Matrix, paramReluBackwardW) {
   }
 }
-
 void testParamReluBackwardDiff(int height, int width, int w_height,
                                int w_width) {
   MatrixPtr oGrad = CpuMatrix::create(height, width, false, false);
@@ -1819,6 +1816,36 @@ TEST(Matrix, paramReluBackwardDiff) {
   }
 }

+void testClassificationError(int numSamples, int dim) {
+  MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1);
+  MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1);
+  MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
+  IVectorPtr cpuLabel = std::make_shared<CpuIVector>(numSamples);
+  IVectorPtr gpuLabel = std::make_shared<GpuIVector>(numSamples);
+
+  cpuOutput->randomizeUniform();
+  cpuLabel->rand(dim);
+  gpuOutput->copyFrom(*cpuOutput);
+  gpuLabel->copyFrom(*cpuLabel);
+
+  cpuError->classificationError(cpuOutput, cpuLabel);
+  gpuError->classificationError(gpuOutput, gpuLabel);
+
+  MatrixPtr check = std::make_shared<CpuMatrix>(numSamples, 1);
+  check->copyFrom(*gpuError);
+  MatrixCheckEqual(*cpuError, *check);
+}
+
+TEST(Matrix, classificationError) {
+  for (auto numSamples : {1, 10, 100, 1000, 70000}) {
+    for (auto dim : {1, 10, 100, 1000}) {
+      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
+      testClassificationError(numSamples, dim);
+    }
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);

From 5547a7601b5fda56a8c7edda45ae91ec4e51cacc Mon Sep 17 00:00:00 2001
From: liuyuan04
Date: Thu, 8 Sep 2016 11:10:23 +0800
Subject: [PATCH 08/10] Refine doc of Python Prediction API, replace
 DataProviderWrapperConverter with DataProviderConverter.
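[Editor's note — PATCH 07/10 above] The kernel rewrite assigns one thread block per sample (blockIdx.x) instead of spreading samples over the grid's y dimension, and finds the row argmax with a strided scan plus a shared-memory tree reduction. A standalone CUDA sketch of that pattern — illustrative names, not the hl_matrix_classification_error kernel itself:

```cuda
// One block per row: each thread scans a strided slice of the row, then a
// shared-memory tree reduction yields the row-wide argmax; the error flag is
// 1.0f iff the argmax disagrees with the label.
template <int kBlockSize>
__global__ void rowArgmaxError(const float* in, const int* labels,
                               float* err, int dimN) {
  __shared__ float maxVal[kBlockSize];
  __shared__ int maxIdx[kBlockSize];
  const int tid = threadIdx.x;
  const int row = blockIdx.x;
  in += row * dimN;

  maxVal[tid] = -1e30f;
  maxIdx[tid] = 0;
  for (int col = tid; col < dimN; col += kBlockSize) {
    if (in[col] > maxVal[tid]) {
      maxVal[tid] = in[col];
      maxIdx[tid] = col;
    }
  }
  __syncthreads();

  for (int stride = kBlockSize / 2; stride > 0; stride /= 2) {
    if (tid < stride && maxVal[tid + stride] > maxVal[tid]) {
      maxVal[tid] = maxVal[tid + stride];
      maxIdx[tid] = maxIdx[tid + stride];
    }
    __syncthreads();
  }
  if (tid == 0) {
    err[row] = (maxIdx[0] == labels[row]) ? 0.0f : 1.0f;
  }
}

// Launch with one block per sample, mirroring the patched host code:
// rowArgmaxError<1024><<<numSamples, 1024>>>(dOut, dLabels, dErr, dim);
```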
---
 doc/ui/predict/predict_sample.py     |  8 ++++----
 doc/ui/predict/swig_py_paddle_en.rst | 30 ++++++++++++++++----------
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/doc/ui/predict/predict_sample.py b/doc/ui/predict/predict_sample.py
index ac16b2b48b9f7..d55d2c730dece 100644
--- a/doc/ui/predict/predict_sample.py
+++ b/doc/ui/predict/predict_sample.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from py_paddle import swig_paddle, DataProviderWrapperConverter
-from paddle.trainer.PyDataProviderWrapper import DenseSlot
+from py_paddle import swig_paddle, DataProviderConverter
+from paddle.trainer.PyDataProvider2 import dense_vector
 from paddle.trainer.config_parser import parse_config

 TEST_DATA = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -89,12 +89,12 @@ def main():
-    conf = parse_config("./mnist_model/trainer_config.conf.norm", "")
+    conf = parse_config("./mnist_model/trainer_config.py", "")
     print conf.data_config.load_data_args
     network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
     assert isinstance(network, swig_paddle.GradientMachine)  # For code hint.
     network.loadParameters("./mnist_model/")
-    converter = DataProviderWrapperConverter(False, [DenseSlot(784)])
+    converter = DataProviderConverter([dense_vector(784)])
     inArg = converter(TEST_DATA)
     print network.forwardTest(inArg)

diff --git a/doc/ui/predict/swig_py_paddle_en.rst b/doc/ui/predict/swig_py_paddle_en.rst
index 9841f124e25a4..b743fc4569146 100644
--- a/doc/ui/predict/swig_py_paddle_en.rst
+++ b/doc/ui/predict/swig_py_paddle_en.rst
@@ -10,27 +10,35 @@ SWIG. The main steps of predicting values in python are:
 * Predict

 Here is a sample python script that shows the typical prediction process for the
-MNIST classification problem.
+MNIST classification problem. The complete sample code can be found at
+:code:`src_root/doc/ui/predict/predict_sample.py`.

 .. literalinclude:: ./predict_sample.py
    :language: python
-   :linenos:
+   :lines: 15-18,90-100,101-104

 The module that does most of the job is py_paddle.swig_paddle; it is generated
 by SWIG and has complete documentation, for more details you can use python's
 :code:`help()` function. Let's walk through the above python script:

-* At the beginning, initialize PaddlePaddle with command line arguments(line 90).
-* Parse the configuration file that is used in training(line 93).
-* Create a neural network at line 95 according the parsed configuration, then
-  load the trained parameters from model at line 97.
-* A utility class for data transformation is created at line 98.
+* At the beginning, use :code:`swig_paddle.initPaddle()` to initialize
+  PaddlePaddle with command line arguments, for more about command line arguments
+  see `Command Line Arguments <../cmd_argument/detail_introduction.html>`_.
+* Parse the configuration file that is used in training with :code:`parse_config()`.
+  Because the data to be predicted usually has no label, and the prediction
+  output is normally the output layer rather than the cost layer, you should
+  modify the configuration file accordingly before using it for prediction.
+* Create a neural network with
+  :code:`swig_paddle.GradientMachine.createFromConfigProto()`, which takes the
+  parsed configuration :code:`conf.model_config` as argument. Then load the
+  trained parameters from the model with :code:`network.loadParameters()`.
+* Create a data converter object of utility class :code:`DataProviderConverter`.
   Note: As swig_paddle can only accept C++ matrices, we offer a utility
-  class DataProviderWraaperConverter that can accept the same input data with
-  PyDataProviderWrapper, for more information please refer to document
+  class DataProviderConverter that can accept the same input data with
+  PyDataProvider2, for more information please refer to document
   of `PyDataProvider2 <../data_provider/pydataprovider2.html>`_.
-* Do the prediction and output the result at line 100, forwardTest is another
-  utility class that directly takes the activations of the output layer.
+* Do the prediction with :code:`forwardTest()`, which takes the converted
+  input data and outputs the activations of the output layer.

 Here is a typical output:

From d6d91223b5303d0e7dbe10abdde832d904e5bb6c Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 8 Sep 2016 15:26:58 +0800
Subject: [PATCH 09/10] fix docker tag mistake

Change-Id: Ia84860cdc25945ececba84fc9807495c9e5f047b
---
 doc/build/docker_install.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/doc/build/docker_install.md b/doc/build/docker_install.md
index 997a755956bf2..3cd9d1730a22b 100644
--- a/doc/build/docker_install.md
+++ b/doc/build/docker_install.md
@@ -8,12 +8,12 @@ Docker is a tool designed to make it easier to create, deploy, and run applications
 ### PaddlePaddle Docker images

 There are six Docker images:

-- paddledev/paddle:latest-cpu: PaddlePaddle CPU binary image.
-- paddledev/paddle:latest-gpu: PaddlePaddle GPU binary image.
-- paddledev/paddle:latest-cpu-devel: PaddlePaddle CPU binary image plus source code.
-- paddledev/paddle:latest-gpu-devel: PaddlePaddle GPU binary image plus source code.
-- paddledev/paddle:latest-cpu-demo: PaddlePaddle CPU binary image plus source code and demo
-- paddledev/paddle:latest-gpu-demo: PaddlePaddle GPU binary image plus source code and demo
+- paddledev/paddle:cpu-latest: PaddlePaddle CPU binary image.
+- paddledev/paddle:gpu-latest: PaddlePaddle GPU binary image.
+- paddledev/paddle:cpu-devel-latest: PaddlePaddle CPU binary image plus source code.
+- paddledev/paddle:gpu-devel-latest: PaddlePaddle GPU binary image plus source code.
+- paddledev/paddle:cpu-demo-latest: PaddlePaddle CPU binary image plus source code and demo
+- paddledev/paddle:gpu-demo-latest: PaddlePaddle GPU binary image plus source code and demo

 Tags with latest will be replaced by a released version.
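[Editor's note — PATCH 08/10, whose rst walkthrough ends just above] Condensed into one fragment, the prediction flow that doc now describes looks like the sketch below. It mirrors the patched predict_sample.py; the model directory, the flags passed to initPaddle, and the 784-wide dense input are placeholders taken from that MNIST sample.

```python
from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import dense_vector
from paddle.trainer.config_parser import parse_config

swig_paddle.initPaddle("--use_gpu=0")  # initialize PaddlePaddle once per process
conf = parse_config("./mnist_model/trainer_config.py", "")
network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
network.loadParameters("./mnist_model/")  # directory with trained parameters
converter = DataProviderConverter([dense_vector(784)])
inArg = converter([[[0.0] * 784]])  # one sample, one dense slot
print network.forwardTest(inArg)
```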
@@ -23,7 +23,7 @@ You have to install Docker on your machine which has linux kernel version 3.10+
 You can use ```docker pull``` to download images first, or just launch a container with ```docker run```:

 ```bash
-docker run -it paddledev/paddle:lastest-cpu
+docker run -it paddledev/paddle:cpu-latest
 ```

 If you want to launch a container with GPU support, you need to set some environment variables at the same time:

 ```bash
 export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
 export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
-docker run -it paddledev/paddle:latest-gpu
+docker run -it paddledev/paddle:gpu-latest
 ```

 ### Notice

From dbaabc94fb0b21b7bf91132eab5de954143d870b Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Wed, 7 Sep 2016 15:48:36 +0800
Subject: [PATCH 10/10] fix unittest of test_RecurrentGradientMachine, and some
 tiny doc update

Change-Id: I028e402c964ca4f4431cbf8153bea4379dd4df70
---
 doc/demo/imagenet_model/resnet_model.md       |  2 +-
 doc/demo/rec/ml_regression.rst                |  2 +-
 paddle/gserver/tests/sequenceGen.py           | 30 +++++++++++--------
 .../gserver/tests/sequence_layer_group.conf   | 10 +++----
 .../tests/sequence_nest_layer_group.conf      | 10 +++----
 5 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/doc/demo/imagenet_model/resnet_model.md b/doc/demo/imagenet_model/resnet_model.md
index 2e5c7f324434d..5403ab9f17d23 100644
--- a/doc/demo/imagenet_model/resnet_model.md
+++ b/doc/demo/imagenet_model/resnet_model.md
@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following examples
 ### C++ Interface

-First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`.
+First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`.
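[Editor's note — PATCH 10/10, which begins above] The doc edits in this patch swap define_py_data_sources for define_py_data_sources2, the PyDataProvider2 entry point. Schematically, a trainer config declares its provider as in the fragment below — an editor's illustration modeled on the sequence_layer_group.conf change later in this same patch, with a placeholder dictionary path:

```python
# Build the word dictionary the provider's init hook will receive.
dict_file = dict()
for line_count, line in enumerate(open('dict.txt', 'r')):  # placeholder path
    dict_file[line.strip()] = line_count

define_py_data_sources2(train_list='train.list',
                        test_list=None,
                        module='sequenceGen',  # module defining the @provider
                        obj='process',         # provider function to call
                        args={"dict_file": dict_file})
```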
 ```
 train_list = 'train.list' if not is_test else None

diff --git a/doc/demo/rec/ml_regression.rst b/doc/demo/rec/ml_regression.rst
index 4917f873a934d..0c14e4f5bb7f8 100644
--- a/doc/demo/rec/ml_regression.rst
+++ b/doc/demo/rec/ml_regression.rst
@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
 * Text Convolution Pooling Layer, `text_conv_pool
   <../../ui/api/trainer_config_helpers/networks.html
   #trainer_config_helpers.networks.text_conv_pool>`_
-* Declare Python Data Sources, `define_py_data_sources
+* Declare Python Data Sources, `define_py_data_sources2
   <../../ui/api/trainer_config_helpers/data_sources.html>`_

 Data Provider

diff --git a/paddle/gserver/tests/sequenceGen.py b/paddle/gserver/tests/sequenceGen.py
index dd2b90dd4986c..e4727e472d446 100644
--- a/paddle/gserver/tests/sequenceGen.py
+++ b/paddle/gserver/tests/sequenceGen.py
@@ -18,27 +18,33 @@
 import os
 import sys

-from paddle.trainer.PyDataProviderWrapper import *
+from paddle.trainer.PyDataProvider2 import *

-@init_hook_wrapper
-def hook(obj, dict_file, **kwargs):
-    obj.word_dict = dict_file
-    obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)]
-    obj.logger.info('dict len : %d' % (len(obj.word_dict)))
+def hook(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sequence(len(settings.word_dict)),
+                            integer_value_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))

-@provider(use_seq=True, init_hook=hook)
-def process(obj, file_name):
+@provider(init_hook=hook)
+def process(settings, file_name):
     with open(file_name, 'r') as fdata:
         for line in fdata:
             label, comment = line.strip().split('\t')
             label = int(''.join(label.split()))
             words = comment.split()
-            word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+            word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
             yield word_slot, [label]

 ## for hierarchical sequence network
-@provider(use_seq=True, init_hook=hook)
-def process2(obj, file_name):
+def hook2(settings, dict_file, **kwargs):
+    settings.word_dict = dict_file
+    settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)),
+                            integer_value_sub_sequence(3)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+
+@provider(init_hook=hook2)
+def process2(settings, file_name):
     with open(file_name) as fdata:
         label_list = []
         word_slot_list = []
@@ -47,7 +53,7 @@ def process2(obj, file_name):
         for line in fdata:
             if (len(line)) > 1:
                 label,comment = line.strip().split('\t')
                 label = int(''.join(label.split()))
                 words = comment.split()
-                word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
+                word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
                 label_list.append([label])
                 word_slot_list.append(word_slot)
             else:

diff --git a/paddle/gserver/tests/sequence_layer_group.conf b/paddle/gserver/tests/sequence_layer_group.conf
index 9ad2b3762845f..ac031b31280df 100644
--- a/paddle/gserver/tests/sequence_layer_group.conf
+++ b/paddle/gserver/tests/sequence_layer_group.conf
@@ -21,11 +21,11 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count

-define_py_data_sources(train_list='gserver/tests/Sequence/train.list',
-                       test_list=None,
-                       module='sequenceGen',
-                       obj='process',
-                       args={"dict_file":dict_file})
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list',
+                        test_list=None,
+                        module='sequenceGen',
+                        obj='process',
+                        args={"dict_file":dict_file})

 settings(batch_size=5)
 ######################## network configure ################################

diff --git a/paddle/gserver/tests/sequence_nest_layer_group.conf b/paddle/gserver/tests/sequence_nest_layer_group.conf
index 8c3a08f16cd1c..38c60b657b969 100644
--- a/paddle/gserver/tests/sequence_nest_layer_group.conf
+++ b/paddle/gserver/tests/sequence_nest_layer_group.conf
@@ -21,11 +21,11 @@ dict_file = dict()
 for line_count, line in enumerate(open(dict_path, "r")):
     dict_file[line.strip()] = line_count

-define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest',
-                       test_list=None,
-                       module='sequenceGen',
-                       obj='process2',
-                       args={"dict_file":dict_file})
+define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest',
+                        test_list=None,
+                        module='sequenceGen',
+                        obj='process2',
+                        args={"dict_file":dict_file})

 settings(batch_size=2)

 ######################## network configure ################################
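[Editor's note — PATCH 10/10 above] The sequenceGen.py migration shows the PyDataProvider2 contract: an init hook declares settings.input_types, and the @provider-decorated generator yields one sample at a time. For reference, a minimal provider in the same style — essentially the patch's own process() with hypothetical names:

```python
from paddle.trainer.PyDataProvider2 import provider, integer_value_sequence

def my_hook(settings, dict_file, **kwargs):
    # Stash the vocabulary and declare one integer-sequence slot per input:
    # word ids over the dictionary, and a label in [0, 3).
    settings.word_dict = dict_file
    settings.input_types = [integer_value_sequence(len(dict_file)),
                            integer_value_sequence(3)]

@provider(init_hook=my_hook)
def my_process(settings, file_name):
    # Each line is "<label>\t<space-separated words>".
    with open(file_name) as fdata:
        for line in fdata:
            label, comment = line.strip().split('\t')
            words = comment.split()
            word_ids = [settings.word_dict[w] for w in words
                        if w in settings.word_dict]
            yield word_ids, [int(label)]
```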