-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement cuda kernel for index_sample. #30380
Changes from 6 commits
b04cc91
866dd28
d18b4bf
170e5e1
98a9af7
fec47c5
d931a4a
ddd52bd
ce289f4
aadcd13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,18 +12,210 @@ | |||||||||
// See the License for the specific language governing permissions and | ||||||||||
// limitations under the License. | ||||||||||
|
||||||||||
#include "paddle/fluid/framework/op_registry.h" | ||||||||||
#include "paddle/fluid/operators/index_sample_op.h" | ||||||||||
#include "paddle/fluid/platform/cuda_device_function.h" | ||||||||||
#include "paddle/fluid/platform/cuda_primitives.h" | ||||||||||
|
||||||||||
namespace paddle { | ||||||||||
namespace operators { | ||||||||||
|
||||||||||
using platform::PADDLE_CUDA_NUM_THREADS; | ||||||||||
using Tensor = framework::Tensor; | ||||||||||
using LoDTensor = framework::LoDTensor; | ||||||||||
|
||||||||||
template <typename T, typename IndexT = int> | ||||||||||
__global__ void IndexSampleForward(const IndexT* index, const T* in_data, | ||||||||||
T* out_data, size_t index_length, | ||||||||||
size_t input_length, size_t batch_size) { | ||||||||||
int ix = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||
int iy = blockDim.y * blockIdx.y + threadIdx.y; | ||||||||||
int tid = iy * index_length + ix; | ||||||||||
int tid_x = iy * input_length + ix; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 变量名可以起的更直观一些,比如你这里的 |
||||||||||
|
||||||||||
if (ix < index_length & iy < batch_size) { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BlockDim.x最小值为32。当index_length<32时,一个block里面连续的32个线程会有空闲?后续可以再看看有没有更好的并行方案。 |
||||||||||
IndexT idx = index[tid]; | ||||||||||
out_data[tid] = in_data[tid_x - ix + idx]; | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
template <typename T, typename IndexT = int> | ||||||||||
__global__ void IndexSampleGradDefault(const IndexT* index, T* in_grad, | ||||||||||
const T* out_grad, size_t index_length, | ||||||||||
size_t input_length, size_t batch_size) { | ||||||||||
int ix = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||
int iy = blockDim.y * blockIdx.y + threadIdx.y; | ||||||||||
int tid = iy * index_length + ix; | ||||||||||
int tid_y = iy * input_length + ix; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 变量名命名建议,同上。为什么前向的kernel里面叫 |
||||||||||
|
||||||||||
if (ix < index_length & iy < batch_size) { | ||||||||||
IndexT idx = index[tid]; | ||||||||||
platform::CudaAtomicAdd(&(in_grad[tid_y - ix + idx]), out_grad[tid]); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
template <typename T, typename IndexT = int> | ||||||||||
__global__ void IndexSampleGradSpecial(const IndexT* index, T* in_grad, | ||||||||||
const T* out_grad, size_t index_length, | ||||||||||
size_t input_length, size_t batch_size) { | ||||||||||
int ix = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||
int iy = blockDim.y * blockIdx.y + threadIdx.y; | ||||||||||
int tid = iy * index_length + ix; | ||||||||||
int tid_y = iy * input_length + ix; | ||||||||||
|
||||||||||
if (ix < index_length & iy < batch_size) { | ||||||||||
IndexT idx = index[tid]; | ||||||||||
in_grad[tid_y - ix + idx] = out_grad[tid]; | ||||||||||
} | ||||||||||
} | ||||||||||
zh794390558 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
template <typename DeviceContext, typename T> | ||||||||||
class IndexSampleCUDAKernel : public framework::OpKernel<T> { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里建议改成特化index_sample.h中 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据建议修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个没有改?我是建议改成如下方式: Paddle/paddle/fluid/operators/sum_op.cu Lines 230 to 233 in f89da4a
这样L92的检查就可以去掉了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经按要求修改 |
||||||||||
public: | ||||||||||
void Compute(const framework::ExecutionContext& ctx) const override { | ||||||||||
auto* input = ctx.Input<LoDTensor>("X"); | ||||||||||
auto* index = ctx.Input<LoDTensor>("Index"); | ||||||||||
auto* output = ctx.Output<LoDTensor>("Out"); | ||||||||||
|
||||||||||
const auto& index_type = index->type(); | ||||||||||
bool index_type_match = index_type == framework::proto::VarType::INT64 || | ||||||||||
index_type == framework::proto::VarType::INT32; | ||||||||||
PADDLE_ENFORCE_EQ(index_type_match, true, | ||||||||||
platform::errors::InvalidArgument( | ||||||||||
"Input(Index) holds the wrong type, it holds %s, but " | ||||||||||
"desires to be %s or %s", | ||||||||||
paddle::framework::DataTypeToString(index_type), | ||||||||||
paddle::framework::DataTypeToString( | ||||||||||
framework::proto::VarType::INT32), | ||||||||||
paddle::framework::DataTypeToString( | ||||||||||
framework::proto::VarType::INT64))); | ||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||
platform::is_gpu_place(ctx.GetPlace()), true, | ||||||||||
platform::errors::InvalidArgument("It must use CUDAPlace.")); | ||||||||||
|
||||||||||
const auto* in_data = input->data<T>(); | ||||||||||
auto* out_data = output->mutable_data<T>(ctx.GetPlace()); | ||||||||||
auto stream = | ||||||||||
ctx.template device_context<platform::CUDADeviceContext>().stream(); | ||||||||||
|
||||||||||
auto input_dim = input->dims(); | ||||||||||
auto index_dim = index->dims(); | ||||||||||
size_t batch_size = input_dim[0]; | ||||||||||
size_t input_length = input_dim[1]; | ||||||||||
size_t index_length = index_dim[1]; | ||||||||||
|
||||||||||
auto block_width = platform::RoundToPowerOfTwo(index_length); | ||||||||||
int block_height = | ||||||||||
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; | ||||||||||
|
||||||||||
dim3 block_dim(block_width, block_height); | ||||||||||
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, | ||||||||||
(batch_size + block_dim.y - 1) / block_dim.y); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
|
||||||||||
if (index_type == framework::proto::VarType::INT64) { | ||||||||||
const int64_t* index_data = index->data<int64_t>(); | ||||||||||
IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, in_data, out_data, index_length, input_length, | ||||||||||
batch_size); | ||||||||||
} else if (index_type == framework::proto::VarType::INT32) { | ||||||||||
const int* index_data = index->data<int>(); | ||||||||||
IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, in_data, out_data, index_length, input_length, | ||||||||||
batch_size); | ||||||||||
} | ||||||||||
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op实现里面不用同步。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除同步操作 |
||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
template <typename DeviceContext, typename T> | ||||||||||
class IndexSampleGradCUDAKernel : public framework::OpKernel<T> { | ||||||||||
public: | ||||||||||
void Compute(const framework::ExecutionContext& ctx) const override { | ||||||||||
auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out")); | ||||||||||
auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X")); | ||||||||||
auto* index = ctx.Input<LoDTensor>("Index"); | ||||||||||
|
||||||||||
const auto* output_grad_data = output_grad->data<T>(); | ||||||||||
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); | ||||||||||
|
||||||||||
const auto& index_type = index->type(); | ||||||||||
bool index_type_match = index_type == framework::proto::VarType::INT64 || | ||||||||||
index_type == framework::proto::VarType::INT32; | ||||||||||
PADDLE_ENFORCE_EQ(index_type_match, true, | ||||||||||
platform::errors::InvalidArgument( | ||||||||||
"Input(Index) holds the wrong type, it holds %s, but " | ||||||||||
"desires to be %s or %s", | ||||||||||
paddle::framework::DataTypeToString(index_type), | ||||||||||
paddle::framework::DataTypeToString( | ||||||||||
framework::proto::VarType::INT32), | ||||||||||
paddle::framework::DataTypeToString( | ||||||||||
framework::proto::VarType::INT64))); | ||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||
platform::is_gpu_place(ctx.GetPlace()), true, | ||||||||||
platform::errors::InvalidArgument("It must use CUDAPlace.")); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个检查可以删掉了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除检查判断 |
||||||||||
|
||||||||||
auto stream = | ||||||||||
ctx.template device_context<platform::CUDADeviceContext>().stream(); | ||||||||||
|
||||||||||
auto input_num = input_grad->numel(); | ||||||||||
auto input_dim = input_grad->dims(); | ||||||||||
auto index_dim = index->dims(); | ||||||||||
size_t batch_size = index_dim[0]; | ||||||||||
size_t input_length = input_dim[1]; | ||||||||||
size_t index_length = index_dim[1]; | ||||||||||
|
||||||||||
auto block_width = platform::RoundToPowerOfTwo(index_length); | ||||||||||
auto block_height = | ||||||||||
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; | ||||||||||
|
||||||||||
dim3 block_dim(block_width, block_height); | ||||||||||
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, | ||||||||||
(batch_size + block_dim.y - 1) / block_dim.y); | ||||||||||
|
||||||||||
platform::GpuMemsetAsync(input_grad_data, 0, | ||||||||||
sizeof(T) * input_length * batch_size, stream); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里改成调用如下函数: Paddle/paddle/fluid/operators/trace_op.h Lines 219 to 221 in f89da4a
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
|
||||||||||
if (index_type == framework::proto::VarType::INT64) { | ||||||||||
const int64_t* index_data = index->data<int64_t>(); | ||||||||||
if (index_length == 1) { | ||||||||||
IndexSampleGradSpecial<T, int64_t><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||
input_length, batch_size); | ||||||||||
} else { | ||||||||||
IndexSampleGradDefault<T, int64_t><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||
input_length, batch_size); | ||||||||||
} | ||||||||||
} else if (index_type == framework::proto::VarType::INT32) { | ||||||||||
const int* index_data = index->data<int>(); | ||||||||||
if (index_length == 1) { | ||||||||||
IndexSampleGradSpecial<T, int><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||
input_length, batch_size); | ||||||||||
} else { | ||||||||||
IndexSampleGradDefault<T, int><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||
input_length, batch_size); | ||||||||||
} | ||||||||||
} | ||||||||||
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op实现里面不用同步。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后续已删除 |
||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
} // namespace operators | ||||||||||
} // namespace paddle | ||||||||||
|
||||||||||
namespace ops = paddle::operators; | ||||||||||
REGISTER_OP_CUDA_KERNEL( | ||||||||||
index_sample, | ||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||
REGISTER_OP_CUDA_KERNEL( | ||||||||||
index_sample_grad, | ||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, | ||||||||||
int64_t>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个实际没有用到?