Integrate warp-ctc as WarpCTCLayer, including unitest and layer inter… #651
CMakeLists.txt

@@ -94,6 +94,11 @@ endif()
 if(NOT WITH_GPU)
     add_definitions(-DPADDLE_ONLY_CPU)
     add_definitions(-DHPPL_STUB_FUNC)
+
+    if(WITH_DSO)
+        add_definitions(-DPADDLE_USE_DSO)
+    endif(WITH_DSO)
+
     list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
 else()
     if(${CUDA_VERSION_MAJOR} GREATER 6)

Comment (on add_definitions(-DPADDLE_USE_DSO)): What does DSO refer to?

Reply: DSO is short for dynamically loading shared objects (.so). @gangliao implemented it to load some CUDA-related libraries, so it was previously used only when WITH_GPU was set. warp-ctc is also loaded via DSO; since warp-ctc comes in CPU-GPU and CPU-only builds, some changes were made so that Paddle supports DSO in the CPU-only build as well.
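To make the mechanism described in the reply concrete, here is a minimal sketch of how code can branch on the PADDLE_USE_DSO macro. ResolveCtcLoss and the library path are hypothetical illustrations, not Paddle's actual code; only compute_ctc_loss comes from warp-ctc's public ctc.h.

    #include <dlfcn.h>
    #include "warp-ctc/include/ctc.h"  // declares compute_ctc_loss

    typedef decltype(&compute_ctc_loss) ctc_loss_fn;

    // Hypothetical helper: resolve warp-ctc's entry point either at
    // runtime (DSO build) or at link time (direct build).
    ctc_loss_fn ResolveCtcLoss() {
    #ifdef PADDLE_USE_DSO
      // Runtime loading: a CPU-only build needs no CUDA libraries at
      // link time, only the .so present when the symbol is resolved.
      void* handle = dlopen("libwarpctc.so", RTLD_LAZY);
      if (handle == nullptr) return nullptr;  // caller must check
      return reinterpret_cast<ctc_loss_fn>(dlsym(handle, "compute_ctc_loss"));
    #else
      // Direct linking: the symbol is resolved by the linker.
      return &compute_ctc_loss;
    #endif
    }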
hl_dso_loader.h

@@ -18,10 +18,6 @@ limitations under the License. */
 #include <dlfcn.h>
 #include <string>
 #include <memory>
-#include <cuda_runtime.h>
-#include <cublas_v2.h>
-#include <curand.h>
-#include <cudnn.h>
 #include "hl_base.h"

 /**

@@ -56,4 +52,12 @@ void GetCudartDsoHandle(void** dso_handle);
  */
 void GetCurandDsoHandle(void** dso_handle);

+/**
+ * @brief load the DSO of warp-ctc
+ *
+ * @param **dso_handle dso handle
+ *
+ */
+void GetWarpctcDsoHandle(void** dso_handle);

 #endif  // HL_DSO_LOADER_H_

Comment (on GetWarpctcDsoHandle): CTC is an acronym and should be uppercase. DSO also seems to be an acronym? So this would be GetWarpCTCDSOHandle.

Reply: @gangliao should DSO be uniformly changed to uppercase?

Reply: I don't think it needs to change; Dso is fairly common, similar to the TF loader. When I first wrote this we only supported Linux, so I used the .so naming and did not consider DLL-style names.

Reply: If it is an acronym, it should be all uppercase.
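As a rough illustration of what the corresponding definition might look like (the real implementation lives in the loader's .cc file and its search-path logic is more involved; the library name below and the use of glog's CHECK are assumptions):

    #include <dlfcn.h>

    void GetWarpctcDsoHandle(void** dso_handle) {
      // RTLD_LAZY defers symbol resolution until a symbol is first used.
      *dso_handle = dlopen("libwarpctc.so", RTLD_LAZY);  // assumed path
      CHECK(*dso_handle != nullptr)
          << "Failed to load libwarpctc.so: " << dlerror();
    }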
hl_warpctc_wrap.h (new file)

@@ -0,0 +1,94 @@
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_WARPCTC_WRAP_H_
#define HL_WARPCTC_WRAP_H_

#include "hl_base.h"
/// #include "hl_cuda.h"

Comment (on the commented-out include): If the code is not needed, please delete it rather than commenting it out.

Reply: OK.

#include "warp-ctc/include/ctc.h"

typedef ctcStatus_t hl_warpctc_status_t;
typedef ctcOptions hl_warpctc_options_t;

/**
 * @brief Initialize CTC options.
 *
 * @param[in]   blank    blank label used in the CTC loss function.
 * @param[in]   useGpu   whether to use the GPU.
 * @param[out]  options  handle storing CPU or GPU information.
 */
extern void hl_warpctc_init(const size_t blank,
                            bool useGpu,
                            hl_warpctc_options_t* options);

/**
 * @brief Compute the connectionist temporal classification (CTC) loss,
 *        and optionally the gradient with respect to the inputs.
 *
 *        If batchGrad == nullptr, only the CTC loss is computed;
 *        if batchGrad != nullptr, both the loss and the gradient
 *        are computed.
 *
 * @param[in]   batchInput       batch matrix of input probabilities, in
 *                               maxSequenceLength x numSequences x numClasses
 *                               (row-major) format.
 * @param[out]  batchGrad        batch matrix of gradients.
 * @param[in]   cpuLabels        labels, always in CPU memory.
 * @param[in]   cpuLabelLengths  lengths of all labels, in CPU memory.
 * @param[in]   cpuInputLengths  lengths of all sequences, in CPU memory.
 * @param[in]   numClasses       number of possible output symbols.
 * @param[in]   numSequences     number of sequences.
 * @param[out]  cpuCosts         cost of each sequence, in CPU memory.
 * @param[out]  workspace        workspace for temporary results.
 * @param[in]   options          handle storing CPU or GPU information.
 */
extern void hl_warpctc_compute_loss(const real* batchInput,
                                    real* batchGrad,
                                    const int* cpuLabels,
                                    const int* cpuLabelLengths,
                                    const int* cpuInputLengths,
                                    const size_t numClasses,
                                    const size_t numSequences,
                                    real* cpuCosts,
                                    void* workspace,
                                    hl_warpctc_options_t* options);

/**
 * @brief Compute the required workspace size.
 *        No memory allocation is performed inside warp-ctc itself.
 *
 * @param[in]   cpuLabelLengths  lengths of all labels, in CPU memory.
 * @param[in]   cpuInputLengths  lengths of all sequences, in CPU memory.
 * @param[in]   numClasses       number of possible output symbols.
 * @param[in]   numSequences     number of sequences.
 * @param[in]   options          handle storing CPU or GPU information.
 * @param[out]  bytes            pointer to a scalar receiving the memory
 *                               requirement in bytes.
 */
extern void hl_warpctc_get_workspace_size(const int* cpuLabelLengths,
                                          const int* cpuInputLengths,
                                          const size_t numClasses,
                                          const size_t numSequences,
                                          hl_warpctc_options_t* options,
                                          size_t* bytes);

#endif  // HL_WARPCTC_WRAP_H_
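A hedged sketch of the intended call sequence for this wrapper API, as implied by the documentation above. The buffers (batchInput, cpuLabels, and so on) mirror the parameter names in the header; allocateWorkspace is a hypothetical placeholder for whatever allocator the caller uses.

    hl_warpctc_options_t options;
    hl_warpctc_init(/* blank */ 0, /* useGpu */ true, &options);

    // 1. Ask warp-ctc how much scratch memory it needs; it allocates
    //    nothing internally.
    size_t workspaceBytes = 0;
    hl_warpctc_get_workspace_size(cpuLabelLengths, cpuInputLengths,
                                  numClasses, numSequences,
                                  &options, &workspaceBytes);

    // 2. Allocate the workspace (device memory when useGpu is true),
    //    then compute loss and gradient in a single call.
    void* workspace = allocateWorkspace(workspaceBytes);  // hypothetical
    hl_warpctc_compute_loss(batchInput, batchGrad,
                            cpuLabels, cpuLabelLengths, cpuInputLengths,
                            numClasses, numSequences,
                            cpuCosts, workspace, &options);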
@@ -447,6 +447,124 @@ void hl_sequence2batch_add(real *batch,
  CHECK_SYNC("hl_sequence2batch_add failed");
}

template<bool normByTimes, bool seq2batch>
__global__
void KeSequence2BatchPadding(real* batch,
                             real* sequence,
                             const int* sequenceStartPositions,
                             const size_t sequenceWidth,
                             const size_t maxSequenceLength,
                             const size_t numSequences) {
  int batchIdx = blockIdx.y;
  int sequenceStart = sequenceStartPositions[batchIdx];
  int sequenceLength = sequenceStartPositions[batchIdx + 1] - sequenceStart;

  int sequenceIdx = blockIdx.x * blockDim.y + threadIdx.y;
  int batchBaseIdx = (sequenceIdx * numSequences + batchIdx) * sequenceWidth;
  int sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth;

  if (sequenceIdx < sequenceLength) {
    if (seq2batch) {
      /* sequence -> batch */
      if (normByTimes) {
        real scale = 1.0f / (real)sequenceLength;
        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
          batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
        }
      } else {
        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
          batch[batchBaseIdx + i] = sequence[sequenceBaseIdx + i];
        }
      }
    } else {
      /* batch -> sequence */
      if (normByTimes) {
        real scale = 1.0f / (real)sequenceLength;
        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
          sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
        }
      } else {
        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
          sequence[sequenceBaseIdx + i] = batch[batchBaseIdx + i];
        }
      }
    }
  } else if (sequenceIdx < maxSequenceLength) {
    if (seq2batch) {
      /* sequence -> batch */
      for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        batch[batchBaseIdx + i] = 0;
      }
    }
  }
}

Comment: Consider merging lines 469-477 into a single for loop (the same applies below); see the sketch after this kernel.
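A sketch of the merge the reviewer suggests: hoist the scale factor so the normByTimes branches collapse into one loop. Since normByTimes and seq2batch are compile-time template parameters, the compiler folds the constant conditions away, so the single-loop form costs nothing at runtime.

    if (sequenceIdx < sequenceLength) {
      // scale is 1 when not normalizing, so both cases share one loop.
      real scale = normByTimes ? 1.0f / (real)sequenceLength : 1.0f;
      for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        if (seq2batch) {
          batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
        } else {
          sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
        }
      }
    } else if (sequenceIdx < maxSequenceLength && seq2batch) {
      // Zero-pad positions past the end of this sequence.
      for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        batch[batchBaseIdx + i] = 0;
      }
    }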
void hl_sequence2batch_copy_padding(real* batch,
                                    real* sequence,
                                    const int* sequenceStartPositions,
                                    const size_t sequenceWidth,
                                    const size_t maxSequenceLength,
                                    const size_t numSequences,
                                    bool normByTimes,
                                    bool seq2batch) {
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(sequenceStartPositions);

  if (!normByTimes && numSequences == 1) {
    size_t elementCount = maxSequenceLength * sequenceWidth;
    if (seq2batch) {
      /* sequence -> batch */
      hl_memcpy_device2device(batch, sequence, sizeof(real) * elementCount);
    } else {
      /* batch -> sequence */
      hl_memcpy_device2device(sequence, batch, sizeof(real) * elementCount);
    }
    return;
  }

  const int CUDA_BLOCK_SIZE = 512;

  /* At least use 32 threads to copy sequenceWidth elements,
     and at least 8 elements for each thread. */
  int blockDimX = ((((sequenceWidth + 7) >> 3) + 31) >> 5) << 5;
  blockDimX = (blockDimX < CUDA_BLOCK_SIZE) ? blockDimX : CUDA_BLOCK_SIZE;

  int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
  dim3 threads(blockDimX, blockDimY);

  int gridDimX = (maxSequenceLength * blockDimX + CUDA_BLOCK_SIZE - 1) /
      CUDA_BLOCK_SIZE;
  int gridDimY = numSequences;
  dim3 grid(gridDimX, gridDimY);

  if (seq2batch) {
    /* sequence -> batch */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
          batch, sequence, sequenceStartPositions,
          sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
          batch, sequence, sequenceStartPositions,
          sequenceWidth, maxSequenceLength, numSequences);
    }
  } else {
    /* batch -> sequence */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
          batch, sequence, sequenceStartPositions,
          sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
          batch, sequence, sequenceStartPositions,
          sequenceWidth, maxSequenceLength, numSequences);
    }
  }

  CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}
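To make the launch geometry concrete, a worked example with illustrative numbers (not taken from the PR):

    // Suppose sequenceWidth = 100, maxSequenceLength = 50, numSequences = 8.
    // blockDimX = ((((100 + 7) >> 3) + 31) >> 5) << 5
    //           = ((13 + 31) >> 5) << 5 = 32  // ceil(100 / 8) threads,
    //                                         // rounded up to a warp (32)
    // blockDimY = 512 / 32 = 16               // 16 timesteps per block
    // gridDimX  = (50 * 32 + 511) / 512 = 4   // ceil(50 / 16) blocks per sequence
    // gridDimY  = 8                           // one block row per sequence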
__device__ inline float my_rsqrt(float x) {
  return rsqrtf(x);
}

Comment: This duplicates line 134; it could be factored out and defined in a single place.