-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement cuda kernel for index_sample. #30380
Changes from 3 commits
b04cc91
866dd28
d18b4bf
170e5e1
98a9af7
fec47c5
d931a4a
ddd52bd
ce289f4
aadcd13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,18 +12,225 @@ | |||||||||||||||||||||||||||||||
// See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||
// limitations under the License. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#include "paddle/fluid/framework/op_registry.h" | ||||||||||||||||||||||||||||||||
#include "paddle/fluid/operators/index_sample_op.h" | ||||||||||||||||||||||||||||||||
#include "paddle/fluid/platform/cuda_primitives.h" | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
namespace paddle { | ||||||||||||||||||||||||||||||||
namespace operators { | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
using platform::PADDLE_CUDA_NUM_THREADS; | ||||||||||||||||||||||||||||||||
using Tensor = framework::Tensor; | ||||||||||||||||||||||||||||||||
using LoDTensor = framework::LoDTensor; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
template <typename T, typename IndexT = int> | ||||||||||||||||||||||||||||||||
__global__ void index_kernel(const IndexT* p_index, const T* p_input, | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 代码需遵循Google C++编程风格,函数命名为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 编码规范确实会修改,下一次PR中这个问题会被处理掉。 |
||||||||||||||||||||||||||||||||
T* p_output, size_t stride_index, | ||||||||||||||||||||||||||||||||
size_t stride_input, size_t height) { | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 指针加 "p_" 前缀是长期保持的习惯,后续修改成与paddle贴合的命名规范。 |
||||||||||||||||||||||||||||||||
int ix = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||||||||||||||||||||||||
int iy = blockDim.y * blockIdx.y + threadIdx.y; | ||||||||||||||||||||||||||||||||
int tid = iy * stride_index + ix; | ||||||||||||||||||||||||||||||||
int tid_x = iy * stride_input + ix; | ||||||||||||||||||||||||||||||||
int tid_y = iy * stride_index + ix; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (ix < stride_index & iy < height) { | ||||||||||||||||||||||||||||||||
IndexT idx = p_index[tid]; | ||||||||||||||||||||||||||||||||
p_output[tid_y] = p_input[tid_x - ix + idx]; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
template <typename T, typename IndexT = int> | ||||||||||||||||||||||||||||||||
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input, | ||||||||||||||||||||||||||||||||
const T* p_output, size_t stride_index, | ||||||||||||||||||||||||||||||||
size_t stride_input, size_t height) { | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 从实际含义上来讲:p_index -> index,p_input -> in_grad,p_output -> out_grad There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据建议修改 |
||||||||||||||||||||||||||||||||
int ix = blockDim.x * blockIdx.x + threadIdx.x; | ||||||||||||||||||||||||||||||||
int iy = blockDim.y * blockIdx.y + threadIdx.y; | ||||||||||||||||||||||||||||||||
int tid = iy * stride_index + ix; | ||||||||||||||||||||||||||||||||
int tid_y = iy * stride_input + ix; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (ix < stride_index & iy < height) { | ||||||||||||||||||||||||||||||||
for (int i = 0; i < stride_index; ++i) { | ||||||||||||||||||||||||||||||||
if (ix == i) { | ||||||||||||||||||||||||||||||||
IndexT idx = p_index[tid]; | ||||||||||||||||||||||||||||||||
T tmp1 = p_output[tid]; | ||||||||||||||||||||||||||||||||
T tmp2 = p_input[tid_y - ix + idx]; | ||||||||||||||||||||||||||||||||
p_input[tid_y - ix + idx] = tmp1 + tmp2; | ||||||||||||||||||||||||||||||||
JamesLim-sy marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
zh794390558 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
template <typename DeviceContext, typename T> | ||||||||||||||||||||||||||||||||
class IndexSampleCUDAKernel : public framework::OpKernel<T> { | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里建议改成特化index_sample.h中 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据建议修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个没有改?我是建议改成如下方式: Paddle/paddle/fluid/operators/sum_op.cu Lines 230 to 233 in f89da4a
这样L92的检查就可以去掉了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经按要求修改 |
||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||
void Compute(const framework::ExecutionContext& ctx) const override { | ||||||||||||||||||||||||||||||||
auto* input = ctx.Input<LoDTensor>("X"); | ||||||||||||||||||||||||||||||||
auto* index = ctx.Input<LoDTensor>("Index"); | ||||||||||||||||||||||||||||||||
auto* output = ctx.Output<LoDTensor>("Out"); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
const auto& index_type = index->type(); | ||||||||||||||||||||||||||||||||
bool index_type_match = index_type == framework::proto::VarType::INT64 || | ||||||||||||||||||||||||||||||||
index_type == framework::proto::VarType::INT32; | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_EQ(index_type_match, true, | ||||||||||||||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||||||||||||||
"Input(Index) holds the wrong type, it holds %s, but " | ||||||||||||||||||||||||||||||||
"desires to be %s or %s", | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString(index_type), | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString( | ||||||||||||||||||||||||||||||||
framework::proto::VarType::INT32), | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString( | ||||||||||||||||||||||||||||||||
framework::proto::VarType::INT64))); | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||||||||||||||
platform::is_gpu_place(ctx.GetPlace()), true, | ||||||||||||||||||||||||||||||||
platform::errors::InvalidArgument("It must use CUDAPlace.")); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto ComputeBlockSize = [](int col) { | ||||||||||||||||||||||||||||||||
if (col > 512) | ||||||||||||||||||||||||||||||||
return 1024; | ||||||||||||||||||||||||||||||||
else if (col > 256) | ||||||||||||||||||||||||||||||||
return 512; | ||||||||||||||||||||||||||||||||
else if (col > 128) | ||||||||||||||||||||||||||||||||
return 256; | ||||||||||||||||||||||||||||||||
else if (col > 64) | ||||||||||||||||||||||||||||||||
return 128; | ||||||||||||||||||||||||||||||||
else if (col > 32) | ||||||||||||||||||||||||||||||||
return 64; | ||||||||||||||||||||||||||||||||
else if (col > 16) | ||||||||||||||||||||||||||||||||
return 32; | ||||||||||||||||||||||||||||||||
else if (col > 8) | ||||||||||||||||||||||||||||||||
return 16; | ||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||
return 8; | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Paddle/paddle/fluid/platform/cuda_device_function.h Lines 36 to 50 in 7e9f336
可使用这个函数代替吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以的,这块的写法非常不美观,肯定替换掉。 |
||||||||||||||||||||||||||||||||
const auto* in_data = input->data<T>(); | ||||||||||||||||||||||||||||||||
auto* out_data = output->mutable_data<T>(ctx.GetPlace()); | ||||||||||||||||||||||||||||||||
auto stream = | ||||||||||||||||||||||||||||||||
ctx.template device_context<platform::CUDADeviceContext>().stream(); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto input_dim = input->dims(); | ||||||||||||||||||||||||||||||||
auto index_dim = index->dims(); | ||||||||||||||||||||||||||||||||
size_t batch_size = input_dim[0]; | ||||||||||||||||||||||||||||||||
size_t input_length = input_dim[1]; | ||||||||||||||||||||||||||||||||
size_t index_length = index_dim[1]; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto block_width = ComputeBlockSize(index_length); | ||||||||||||||||||||||||||||||||
int block_height = | ||||||||||||||||||||||||||||||||
ComputeBlockSize(index_length * batch_size) / block_width; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
dim3 block_dim(block_width, block_height); | ||||||||||||||||||||||||||||||||
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, | ||||||||||||||||||||||||||||||||
(batch_size + block_dim.y - 1) / block_dim.y); | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (index_type == framework::proto::VarType::INT64) { | ||||||||||||||||||||||||||||||||
const int64_t* index_data = index->data<int64_t>(); | ||||||||||||||||||||||||||||||||
index_kernel<T, int64_t><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||||||||||||||||||||||||
index_data, in_data, out_data, index_length, input_length, | ||||||||||||||||||||||||||||||||
batch_size); | ||||||||||||||||||||||||||||||||
} else if (index_type == framework::proto::VarType::INT32) { | ||||||||||||||||||||||||||||||||
const int* index_data = index->data<int>(); | ||||||||||||||||||||||||||||||||
index_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||||||||||||||||||||||||
index_data, in_data, out_data, index_length, input_length, | ||||||||||||||||||||||||||||||||
batch_size); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op实现里面不用同步。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除同步操作 |
||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
template <typename DeviceContext, typename T> | ||||||||||||||||||||||||||||||||
class IndexSampleGradCUDAKernel : public framework::OpKernel<T> { | ||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||
void Compute(const framework::ExecutionContext& ctx) const override { | ||||||||||||||||||||||||||||||||
auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out")); | ||||||||||||||||||||||||||||||||
auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X")); | ||||||||||||||||||||||||||||||||
auto* index = ctx.Input<LoDTensor>("Index"); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
const auto* output_grad_data = output_grad->data<T>(); | ||||||||||||||||||||||||||||||||
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
const auto& index_type = index->type(); | ||||||||||||||||||||||||||||||||
bool index_type_match = index_type == framework::proto::VarType::INT64 || | ||||||||||||||||||||||||||||||||
index_type == framework::proto::VarType::INT32; | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_EQ(index_type_match, true, | ||||||||||||||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||||||||||||||
"Input(Index) holds the wrong type, it holds %s, but " | ||||||||||||||||||||||||||||||||
"desires to be %s or %s", | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString(index_type), | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString( | ||||||||||||||||||||||||||||||||
framework::proto::VarType::INT32), | ||||||||||||||||||||||||||||||||
paddle::framework::DataTypeToString( | ||||||||||||||||||||||||||||||||
framework::proto::VarType::INT64))); | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||||||||||||||
platform::is_gpu_place(ctx.GetPlace()), true, | ||||||||||||||||||||||||||||||||
platform::errors::InvalidArgument("It must use CUDAPlace.")); | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个检查可以删掉了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除检查判断 |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto ComputeBlockSize = [](int col) { | ||||||||||||||||||||||||||||||||
if (col > 512) | ||||||||||||||||||||||||||||||||
return 1024; | ||||||||||||||||||||||||||||||||
else if (col > 256) | ||||||||||||||||||||||||||||||||
return 512; | ||||||||||||||||||||||||||||||||
else if (col > 128) | ||||||||||||||||||||||||||||||||
return 256; | ||||||||||||||||||||||||||||||||
else if (col > 64) | ||||||||||||||||||||||||||||||||
return 128; | ||||||||||||||||||||||||||||||||
else if (col > 32) | ||||||||||||||||||||||||||||||||
return 64; | ||||||||||||||||||||||||||||||||
else if (col > 16) | ||||||||||||||||||||||||||||||||
return 32; | ||||||||||||||||||||||||||||||||
else if (col > 8) | ||||||||||||||||||||||||||||||||
return 16; | ||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||
return 8; | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
auto stream = | ||||||||||||||||||||||||||||||||
ctx.template device_context<platform::CUDADeviceContext>().stream(); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto input_num = input_grad->numel(); | ||||||||||||||||||||||||||||||||
auto input_dim = input_grad->dims(); | ||||||||||||||||||||||||||||||||
auto index_dim = index->dims(); | ||||||||||||||||||||||||||||||||
size_t batch_size = index_dim[0]; | ||||||||||||||||||||||||||||||||
size_t input_length = input_dim[1]; | ||||||||||||||||||||||||||||||||
size_t index_length = index_dim[1]; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto block_width = ComputeBlockSize(index_length); | ||||||||||||||||||||||||||||||||
auto block_height = | ||||||||||||||||||||||||||||||||
ComputeBlockSize(index_length * batch_size) / block_width; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
dim3 block_dim(block_width, block_height); | ||||||||||||||||||||||||||||||||
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, | ||||||||||||||||||||||||||||||||
(batch_size + block_dim.y - 1) / block_dim.y); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
platform::GpuMemsetAsync(input_grad_data, 0, | ||||||||||||||||||||||||||||||||
sizeof(T) * input_length * batch_size, stream); | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里改成调用如下函数: Paddle/paddle/fluid/operators/trace_op.h Lines 219 to 221 in f89da4a
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (index_type == framework::proto::VarType::INT64) { | ||||||||||||||||||||||||||||||||
const int64_t* index_data = index->data<int64_t>(); | ||||||||||||||||||||||||||||||||
index_kernel_grad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||||||||||||||||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||||||||||||||||||||||||
input_length, batch_size); | ||||||||||||||||||||||||||||||||
} else if (index_type == framework::proto::VarType::INT32) { | ||||||||||||||||||||||||||||||||
const int* index_data = index->data<int>(); | ||||||||||||||||||||||||||||||||
index_kernel_grad<T, int><<<grid_dim, block_dim, 0, stream>>>( | ||||||||||||||||||||||||||||||||
index_data, input_grad_data, output_grad_data, index_length, | ||||||||||||||||||||||||||||||||
input_length, batch_size); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op实现里面不用同步。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后续已删除 |
||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
} // namespace operators | ||||||||||||||||||||||||||||||||
} // namespace paddle | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
namespace ops = paddle::operators; | ||||||||||||||||||||||||||||||||
REGISTER_OP_CUDA_KERNEL( | ||||||||||||||||||||||||||||||||
index_sample, | ||||||||||||||||||||||||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||||||||||||||||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||||||||||||||||||||||||
REGISTER_OP_CUDA_KERNEL( | ||||||||||||||||||||||||||||||||
index_sample_grad, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>); | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, float>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, double>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, int>, | ||||||||||||||||||||||||||||||||
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, | ||||||||||||||||||||||||||||||||
int64_t>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个实际没有用到?