Support Add Sub Mul Max Min Pow binary functors in elementwise system #33050
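For orientation, the PR title promises six binary functors in the pointer-args form that the new elementwise system consumes. Only the Mul variant is visible in this excerpt of the diff; the other five below are written by analogy and are assumptions about their shape, not the PR's literal code.

```cpp
// Sketch of the six binary functors named in the PR title, in the
// pointer-args style shown by CudaMulFunctor in the diff below.
// HOSTDEVICE is a local stand-in; in Paddle it expands to
// __host__ __device__ under CUDA.
#include <cmath>

#define HOSTDEVICE

template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(const T* a) const { return a[0] + a[1]; }
};
template <typename T>
struct SubFunctor {
  inline HOSTDEVICE T operator()(const T* a) const { return a[0] - a[1]; }
};
template <typename T>
struct MulFunctor {
  inline HOSTDEVICE T operator()(const T* a) const { return a[0] * a[1]; }
};
template <typename T>
struct MaxFunctor {
  inline HOSTDEVICE T operator()(const T* a) const { return a[0] > a[1] ? a[0] : a[1]; }
};
template <typename T>
struct MinFunctor {
  inline HOSTDEVICE T operator()(const T* a) const { return a[0] < a[1] ? a[0] : a[1]; }
};
template <typename T>
struct PowFunctor {
  // std::pow returns double for integral T; the cast keeps the sketch generic.
  inline HOSTDEVICE T operator()(const T* a) const {
    return static_cast<T>(std::pow(a[0], a[1]));
  }
};
```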
First file in the diff (the ElementwiseMaxKernel implementation):

@@ -35,8 +35,8 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          MaxFunctor<T>(), z);
+    ElementwiseComputeEx<MaxFunctor<T>, platform::CPUDeviceContext, T>(
+        ctx, x, y, axis, MaxFunctor<T>(), z);
   }
 };

Review comment: This is a generic implementation, not a specialized one; do not change it here. If you later want to delete …
Author reply: Revised as suggested.
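The review point here is about template genericity: ElementwiseComputeEx takes the device context as a template parameter, so one body serves every backend, and pinning it to platform::CPUDeviceContext inside the shared implementation would freeze it to CPU. A minimal self-contained sketch of the difference, with stand-in names rather than Paddle's real API:

```cpp
// Why the reviewer objects: a compute routine templated on the device
// context stays reusable for every backend; hard-coding one context does not.
// CPUContext/GPUContext/ElementwiseCompute are illustrative stand-ins.
#include <cstddef>

struct CPUContext {};
struct GPUContext {};

template <typename Functor, typename DeviceContext, typename T>
void ElementwiseCompute(const DeviceContext& dev, const T* x, const T* y,
                        T* z, std::size_t n, Functor f) {
  // A real implementation dispatches to a device-specific loop or kernel;
  // a serial loop keeps this sketch self-contained.
  for (std::size_t i = 0; i < n; ++i) z[i] = f(x[i], y[i]);
}

template <typename T>
struct MaxFunctor {
  T operator()(T a, T b) const { return a > b ? a : b; }
};

int main() {
  float x[3] = {1, 5, 2}, y[3] = {4, 3, 9}, z[3];
  CPUContext cpu;
  // Generic version: the same template instantiates for CPUContext or
  // GPUContext. Hard-coding CPUContext in the shared header, as the diff
  // first did, removes that flexibility; hence "revised as suggested".
  ElementwiseCompute<MaxFunctor<float>, CPUContext, float>(
      cpu, x, y, z, 3, MaxFunctor<float>());
  return z[0] == 4 ? 0 : 1;
}
```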
Second file in the diff (the CUDA ElementwiseMulKernel implementation):
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -25,37 +26,60 @@ namespace paddle {
 namespace operators {

-template <typename T>
-struct SameDimsElemwiseMul<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    MulRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
-                                                              x->numel());
-    for_range(functor);
+template <typename T>
+struct CudaMulFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return args[0] * args[1];
   }
 };

-template <>
-struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    auto size = x->numel();
-    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
-                              PADDLE_CUDA_THREAD_SIZE,
-                          1);
-    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    const half* x2 =
-        reinterpret_cast<const half*>(x->data<platform::float16>());
-    const half* y2 =
-        reinterpret_cast<const half*>(y->data<platform::float16>());
-    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-    SameDimsElemwiseMulCUDAKernel<<<
-        grid_size, block_size, 0,
-        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
-        x2, y2, z2, size);
+template <typename T>
+class ElementwiseMulKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x_var = ctx.InputVar("X");
+    PADDLE_ENFORCE_EQ(x_var != nullptr, true,
+                      platform::errors::InvalidArgument(
+                          "Cannot get input Variable X, Variable name = %s.",
+                          ctx.InputName("X")));
Review comment: This check could use …
Author reply: Revised as suggested.
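The reviewer's suggested replacement is truncated in the capture. Paddle does ship a dedicated null-pointer check, PADDLE_ENFORCE_NOT_NULL, which is one plausible reading of the suggestion. The sketch below uses a local stand-in macro, not Paddle's, to show why the dedicated form reads better than PADDLE_ENFORCE_EQ(ptr != nullptr, true, ...):

```cpp
// A stand-in for a NOT_NULL-style enforce macro: the check states its
// intent ("must not be null") instead of comparing a boolean against true.
#include <stdexcept>
#include <string>

#define ENFORCE_NOT_NULL(ptr, msg)                \
  do {                                            \
    if ((ptr) == nullptr) {                       \
      throw std::runtime_error(std::string(msg)); \
    }                                             \
  } while (0)

int main() {
  int value = 42;
  int* x_var = &value;
  // One argument fewer than the EQ form, and no != nullptr boilerplate.
  ENFORCE_NOT_NULL(x_var, "Cannot get input Variable X");
  return 0;
}
```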
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    framework::Tensor x, *z;
+
+    if (x_var->IsType<framework::SelectedRows>()) {
+      PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
+                        platform::errors::InvalidArgument(
+                            "For elementwise_op, if X is Sparse, Y must be "
+                            "scalar. But received the size of Y = %s.",
+                            y->dims().size()));
+      auto& x_sele = x_var->Get<framework::SelectedRows>();
+      auto out_sele = ctx.Output<framework::SelectedRows>("Out");
+      x = x_sele.value();
+      out_sele->set_rows(x_sele.rows());
+      out_sele->set_height(x_sele.height());
+      out_sele->mutable_value()->Resize(x_sele.value().dims());
+      out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
+      z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
+    } else if (x_var->IsType<framework::LoDTensor>()) {
+      x = x_var->Get<framework::LoDTensor>();
+      z = ctx.Output<framework::LoDTensor>("Out");
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "X's type[%s] is not supported by elementwise_op. X's type should "
+          "be LoDTensor or SelectedRows.",
+          framework::ToTypeName(x_var->Type())));
+    }
Review comment: Don't copy-paste a large block of code; wrap L41 - L70 in a function.
Review comment: L41 - L70 could also be encapsulated into …
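What the reviewer asks for, sketched under simplified stand-in types: hoist the repeated "unpack X, prepare Out" branch into one helper that every binary elementwise kernel (mul, add, sub, and so on) can call instead of duplicating it per op. None of the types below are Paddle's real framework classes.

```cpp
// Sketch of the requested refactor: the dense/sparse input resolution lives
// in one shared helper. DenseTensor/SelectedRows/Variable are stand-ins for
// framework::LoDTensor, framework::SelectedRows, and framework::Variable.
#include <stdexcept>
#include <variant>
#include <vector>

struct DenseTensor { std::vector<float> data; };
struct SelectedRows { std::vector<int> rows; DenseTensor value; };
using Variable = std::variant<DenseTensor, SelectedRows>;

// Picks the dense view of X and the output tensor to write into.
const DenseTensor& ResolveInput(const Variable& x_var, SelectedRows* out_sparse,
                                DenseTensor* out_dense, DenseTensor** z) {
  if (auto* sr = std::get_if<SelectedRows>(&x_var)) {
    out_sparse->rows = sr->rows;  // mirror row metadata once, here
    *z = &out_sparse->value;      // output goes into the sparse value tensor
    return sr->value;
  }
  if (auto* dt = std::get_if<DenseTensor>(&x_var)) {
    *z = out_dense;
    return *dt;
  }
  // Unreachable for a well-formed variant; kept as a defensive error path.
  throw std::invalid_argument("X must be DenseTensor or SelectedRows");
}

int main() {
  Variable x_var = SelectedRows{{0, 2}, DenseTensor{{1.f, 2.f}}};
  SelectedRows out_sparse;
  DenseTensor out_dense, *z = nullptr;
  const DenseTensor& x = ResolveInput(x_var, &out_sparse, &out_dense, &z);
  return (x.data.size() == 2 && z == &out_sparse.value) ? 0 : 1;
}
```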
+
+    z->mutable_data<T>(ctx.GetPlace());
+
+    int axis = ctx.Attr<int>("axis");
+    axis = axis == -1 ? std::abs(x.dims().size() - y->dims().size()) : axis;
+
+    std::vector<const framework::Tensor*> ins = {&x, y};
+    std::vector<framework::Tensor*> outs = {z};
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
   }
 };
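For readers unfamiliar with the pointer-args convention used by CudaMulFunctor: the functor receives its operands packed in a small array, which is what lets the launcher stay generic over arity. A minimal runnable CUDA sketch of the same mechanism; ApplyBinary is an illustrative stand-in for LaunchElementwiseCudaKernel and omits the broadcasting and vectorized loads the real launcher performs.

```cuda
// A binary functor reading operands from a pointer-to-arguments array,
// applied by a generic one-thread-per-element kernel.
#include <cstdio>

template <typename T>
struct CudaMulFunctor {
  __host__ __device__ T operator()(const T* args) const {
    return args[0] * args[1];
  }
};

template <typename T, typename Functor>
__global__ void ApplyBinary(const T* x, const T* y, T* z, int n, Functor f) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T args[2] = {x[i], y[i]};  // pack operands the way the functor expects
    z[i] = f(args);
  }
}

int main() {
  const int n = 4;
  float hx[n] = {1, 2, 3, 4}, hy[n] = {10, 20, 30, 40}, hz[n];
  float *dx, *dy, *dz;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dz, n * sizeof(float));
  cudaMemcpy(dx, hx, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy, n * sizeof(float), cudaMemcpyHostToDevice);
  ApplyBinary<<<1, 128>>>(dx, dy, dz, n, CudaMulFunctor<float>());
  cudaMemcpy(hz, dz, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", hz[i]);  // prints: 10 40 90 160
  cudaFree(dx); cudaFree(dy); cudaFree(dz);
  return 0;
}
```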
Review comment: Could this axis conversion (the axis == -1 rewrite above) also be moved into LaunchElementwiseCudaKernel?
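The conversion the reviewer points at encodes the broadcasting default: axis == -1 means Y aligns against the trailing dimensions of X, which works out to the rank difference. A small sketch of the rule with a worked example; ResolveAxis is an illustrative name, not Paddle's.

```cpp
// axis = -1: Y's dim 0 lines up with X's dim (rank(X) - rank(Y)).
#include <cassert>
#include <cstdlib>

int ResolveAxis(int axis, int x_rank, int y_rank) {
  return axis == -1 ? std::abs(x_rank - y_rank) : axis;
}

int main() {
  // X: (2, 3, 4), Y: (3, 4)  ->  Y aligns starting at X's dim 1.
  assert(ResolveAxis(-1, 3, 2) == 1);
  // An explicit axis is passed through unchanged.
  assert(ResolveAxis(0, 3, 2) == 0);
  return 0;
}
```

Folding this default into LaunchElementwiseCudaKernel, as the reviewer suggests, would give every binary op the same behavior for free instead of each kernel repeating the line.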