Skip to content

Commit

Permalink
refine wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Jun 2, 2021
1 parent f5a2ce7 commit cd40092
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 63 deletions.
47 changes: 23 additions & 24 deletions paddle/fluid/operators/controlflow/compare_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {

#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
template <typename T, typename Enable = void> \
struct Func##Functor { \
struct func { \
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};

DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT

template <typename T>
Expand Down Expand Up @@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
ctx, ins, &outs, functor);
cuda_ctx, ins, &outs, axis, functor);
}
};

Expand All @@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>

#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<double>, void>);
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);

REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL
7 changes: 5 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaAddFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
}
};

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;

const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, axis, CudaMaxFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaMaxFunctor<T>());
}
};

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_min_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaMinFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaMinFunctor<T>());
}
};

Expand Down
35 changes: 23 additions & 12 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,25 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int axis = -1;
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_EQ(x_var != nullptr, true,
platform::errors::InvalidArgument(
"Cannot get input Variable X, Variable name = %s.",
ctx.InputName("X")));
PADDLE_ENFORCE_NOT_NULL(
x_var, platform::errors::InvalidArgument(
"Cannot get input Variable X, Variable name = %s.",
ctx.InputName("X")));
auto* y = ctx.Input<framework::LoDTensor>("Y");

framework::Tensor x, *z;
if (x_var->IsType<framework::SelectedRows>()) {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

if (x_var->IsType<framework::LoDTensor>()) {
x = x_var->Get<framework::LoDTensor>();
z = ctx.Output<framework::LoDTensor>("Out");
axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
} else if (x_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
platform::errors::InvalidArgument(
"For elementwise_op, if X is Sparse, Y must be "
Expand All @@ -58,21 +68,22 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
} else if (x_var->IsType<framework::LoDTensor>()) {
x = x_var->Get<framework::LoDTensor>();
z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
outs.emplace_back(z);
ins.emplace_back(&x);
ins.emplace_back(y);

axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows.",
framework::ToTypeName(x_var->Type())));
}
z->mutable_data<T>(ctx.GetPlace());

std::vector<const framework::Tensor*> ins = {&x, y};
std::vector<framework::Tensor*> outs = {z};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaMulFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
}
};

Expand Down
14 changes: 3 additions & 11 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel(

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
const framework::ExecutionContext &ctx,
const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
std::vector<int> dims_size;
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
}
const auto &cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

if (no_broadcast_flag) {
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
func);
} else {
int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
axis, func);
}
Expand Down
16 changes: 10 additions & 6 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,24 @@ namespace operators {
* To pack the input and output tensors into vector for
* LaunchElementwiseCudaKernel
*/
template <typename T>
void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
std::vector<const framework::Tensor *> *ins,
std::vector<framework::Tensor *> *outs) {
template <typename OutT>
int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
std::vector<const framework::Tensor *> *ins,
std::vector<framework::Tensor *> *outs) {
int axis = -1;
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
ins->emplace_back(x);
z->mutable_data<OutT>(ctx.GetPlace());
outs->emplace_back(z);
ins->emplace_back(x);

if (y != nullptr) {
ins->emplace_back(y);
axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis;
}
return axis;
}

/*
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_pow_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ class ElementwisePowKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaPowFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaPowFunctor<T>());
}
};

Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ class ElementwiseSubKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
PackTensorsIntoVector<T>(ctx, &ins, &outs);
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, CudaSubFunctor<T>());
cuda_ctx, ins, &outs, axis, CudaSubFunctor<T>());
}
};

Expand Down

0 comments on commit cd40092

Please sign in to comment.