Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cuda kernel for index_sample. #30380

Merged
merged 10 commits into from
Feb 3, 2021
223 changes: 215 additions & 8 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,225 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个实际没有用到?

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T, typename IndexT = int>
__global__ void index_kernel(const IndexT* p_index, const T* p_input,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码需遵循Google C++编程风格,函数命名为AxxBxx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

编码规范确实会修改,下一次PR中这个问题会被处理掉。

T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • p_index -> index, p_input -> input, p_output -> output,感觉这里不需要从命名上特意强调这是个ptr。
  • stride_index、stride_input、height,这几个参数我有点对应不上,变量命名能否更直观一些。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

指针加 "p_" 前缀是长期保持的习惯,后续修改成与paddle贴合的命名规范。

int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * stride_index + ix;
int tid_x = iy * stride_input + ix;
int tid_y = iy * stride_index + ix;

if (ix < stride_index & iy < height) {
IndexT idx = p_index[tid];
p_output[tid_y] = p_input[tid_x - ix + idx];
}
}

template <typename T, typename IndexT = int>
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input,
const T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从实际含义上来讲:p_index -> index,p_input -> in_grad,p_output -> out_grad

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * stride_index + ix;
int tid_y = iy * stride_input + ix;

if (ix < stride_index & iy < height) {
for (int i = 0; i < stride_index; ++i) {
if (ix == i) {
IndexT idx = p_index[tid];
T tmp1 = p_output[tid];
T tmp2 = p_input[tid_y - ix + idx];
p_input[tid_y - ix + idx] = tmp1 + tmp2;
}
}
}
}

template <typename DeviceContext, typename T>
class IndexSampleCUDAKernel : public framework::OpKernel<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议改成特化index_sample.h中IndexSampleKernel类的形式。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个没有改?我是建议改成如下方式:

template <typename T>
class SumKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:

这样L92的检查就可以去掉了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经按要求修改

public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("X");
auto* index = ctx.Input<LoDTensor>("Index");
auto* output = ctx.Output<LoDTensor>("Out");

const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));

auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline static int RoundToPowerOfTwo(int dim) {
if (dim > 512) {
return 1024;
} else if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}

可使用这个函数代替吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以的,这块的写法非常不美观,肯定替换掉。

const auto* in_data = input->data<T>();
auto* out_data = output->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();

auto input_dim = input->dims();
auto index_dim = index->dims();
size_t batch_size = input_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];

auto block_width = ComputeBlockSize(index_length);
int block_height =
ComputeBlockSize(index_length * batch_size) / block_width;

dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • cuda并行方案怎么设计的,在PR描述里面补充下。
  • op benchmark里面可能要补充下配置,当前只有1个index_dim=1的配置,最好补充下index_dim>1的配置。
  • 另外看看单测里面有没有index_dim>1的配置


if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_kernel<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
index_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op实现里面不用同步。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除同步操作

}
};

template <typename DeviceContext, typename T>
class IndexSampleGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = ctx.Input<LoDTensor>("Index");

const auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());

const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个检查可以删掉了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除检查判断


auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
};
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();

auto input_num = input_grad->numel();
auto input_dim = input_grad->dims();
auto index_dim = index->dims();
size_t batch_size = index_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];

auto block_width = ComputeBlockSize(index_length);
auto block_height =
ComputeBlockSize(index_length * batch_size) / block_width;

dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);

platform::GpuMemsetAsync(input_grad_data, 0,
sizeof(T) * input_length * batch_size, stream);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里改成调用如下函数:

math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, d_x, static_cast<T>(0.0));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_kernel_grad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
index_kernel_grad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op实现里面不用同步。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续已删除

}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_sample,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);