Skip to content

Commit

Permalink
Revert "[Phi] Move mul op kernel into phi (PaddlePaddle#40833)"
Browse files Browse the repository at this point in the history
This reverts commit 1b49181.
  • Loading branch information
chenwhql committed Jun 15, 2022
1 parent 2b5771c commit ffc4770
Show file tree
Hide file tree
Showing 23 changed files with 267 additions and 295 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP_ITSELF(mul);
USE_OP(mul);
USE_OP(cinn_launch);
USE_OP_ITSELF(elementwise_add);
namespace paddle::framework {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
} // namespace paddle

USE_PASS(build_cinn_pass);
USE_OP_ITSELF(mul);
USE_OP(mul);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(relu_grad);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,6 @@ TEST(CinnCompilerTest, Compile) {

USE_PASS(build_cinn_pass);
USE_PASS(graph_viz_pass);
USE_OP_ITSELF(mul);
USE_OP(mul);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(elementwise_add);
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tests/test_eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@ TEST(test_var_helper, eager_var_helper) {
} // namespace imperative
} // namespace paddle

USE_OP_ITSELF(mul);
USE_OP(mul);
6 changes: 2 additions & 4 deletions paddle/fluid/imperative/tests/test_hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@

PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT);

namespace platform = paddle::platform;
namespace framework = paddle::framework;
Expand Down Expand Up @@ -269,7 +267,7 @@ TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) {
} // namespace imperative
} // namespace paddle

USE_OP_ITSELF(mul);
USE_OP_ITSELF(mul_grad);
USE_OP(mul);
USE_OP(mul_grad);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(elementwise_add_grad);
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tests/test_layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,4 +416,4 @@ TEST(test_layer, test_eager) {
} // namespace imperative
} // namespace paddle

USE_OP_ITSELF(mul);
USE_OP(mul);
8 changes: 2 additions & 6 deletions paddle/fluid/imperative/tests/test_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,10 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul_with_flatten_grad, GPU, ALL_LAYOUT);
#endif

namespace imperative = paddle::imperative;
Expand Down Expand Up @@ -603,8 +599,8 @@ TEST(test_tracer, eager_tracer) {
} // namespace imperative
} // namespace paddle

USE_OP_ITSELF(mul);
USE_OP_ITSELF(mul_grad);
USE_OP(mul);
USE_OP(mul_grad);
USE_OP_ITSELF(reduce_sum);
USE_OP_ITSELF(reduce_sum_grad);
USE_OP_ITSELF(elementwise_add);
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ TEST(fc_op, test) {
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP_ITSELF(mul);
USE_OP(mul);
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ TEST(MulOpConverter, main) {
} // namespace inference
} // namespace paddle

USE_OP_ITSELF(mul);
USE_OP(mul);
5 changes: 1 addition & 4 deletions paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
Expand Down Expand Up @@ -46,9 +46,6 @@ using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;

constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

template <typename XT, typename YT, typename OT>
class MulPrimitiveFactory {
public:
Expand Down
19 changes: 15 additions & 4 deletions paddle/fluid/operators/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mul_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Expand All @@ -32,9 +33,6 @@ namespace operators {
using framework::OpKernelType;
using framework::Tensor;

constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

class MulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
Expand Down Expand Up @@ -283,3 +281,16 @@ REGISTER_OPERATOR(mul_grad, ops::MulGradOp,
MulGradInferShapeFunctor);

REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);

REGISTER_OP_CPU_KERNEL(
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MulKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
mul_grad_grad,
ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
31 changes: 31 additions & 0 deletions paddle/fluid/operators/mul_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mul_op.h"

#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
ops::MulKernel<plat::CUDADeviceContext, double>,
ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
ops::MulGradKernel<plat::CUDADeviceContext, double>,
ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
mul_grad_grad,
ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
207 changes: 207 additions & 0 deletions paddle/fluid/operators/mul_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

template <typename DeviceContext, typename T>
class MulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
Tensor* z = context.Output<Tensor>("Out");
const Tensor x_matrix =
x->dims().size() > 2
? framework::ReshapeToMatrix(
*x, context.template Attr<int>("x_num_col_dims"))
: *x;
const Tensor y_matrix =
y->dims().size() > 2
? framework::ReshapeToMatrix(
*y, context.template Attr<int>("y_num_col_dims"))
: *y;

z->mutable_data<T>(context.GetPlace());
auto z_dim = z->dims();
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}

auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);

blas.MatMul(x_matrix, y_matrix, z);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
}
};

template <typename DeviceContext, typename T>
class MulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));

Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});

auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));

if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2
? framework::ReshapeToMatrix(*dx, x_num_col_dims)
: *dx;

// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
Tensor dy_matrix = dy->dims().size() > 2
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
}
}
};

template <typename DeviceContext, typename T>
class MulDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto x_mat = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_mat = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor&>(*y);

const int m = phi::flatten_to_2d(x->dims(), x_num_col_dims)[0];
const int n = phi::flatten_to_2d(y->dims(), y_num_col_dims)[1];

auto* dout = ctx.Input<framework::LoDTensor>("DOut");
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize({m, n});

auto* ddx = ctx.Input<framework::LoDTensor>("DDX");
auto* ddy = ctx.Input<framework::LoDTensor>("DDY");

auto* dx = ctx.Output<framework::LoDTensor>("DX");
auto* dy = ctx.Output<framework::LoDTensor>("DY");
auto* ddout = ctx.Output<framework::LoDTensor>("DDOut");

Tensor ddout_mat;
if (ddout) {
ddout->set_lod(dout->lod());
// allocate and reshape ddout
ddout->mutable_data<T>(ctx.GetPlace());
ddout_mat.ShareDataWith(*ddout);
ddout_mat.Resize({m, n});
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
// a flag to specify whether ddout value has been set, if flag
// is false, MatMul beta should be 0 to set ddout, if flag is
// true, MatMul beta should be 1 to add result to ddout.
bool ddout_flag = false;
if (ddx) {
auto ddx_mat = ddx->dims().size() > 2
? framework::ReshapeToMatrix(*ddx, x_num_col_dims)
: static_cast<const Tensor&>(*ddx);

// dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N
if (dy) {
dy->set_lod(y->lod());
// allocate and reshape dy
dy->mutable_data<T>(ctx.GetPlace());
Tensor dy_mat = dy->dims().size() > 2
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat);
}
// ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N
if (ddout) {
blas.MatMul(ddx_mat, false, y_mat, false, static_cast<T>(1.0),
&ddout_mat, static_cast<T>(ddout_flag));
ddout_flag = true;
}
}
if (ddy) {
auto ddy_mat = ddy->dims().size() > 2
? framework::ReshapeToMatrix(*ddy, y_num_col_dims)
: static_cast<const Tensor&>(*ddy);
// dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K
if (dx) {
dx->set_lod(x->lod());
// allocate and reshape dx
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_mat = dx->dims().size() > 2
? framework::ReshapeToMatrix(*dx, x_num_col_dims)
: *dx;
blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat);
}
// ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N
if (ddout) {
blas.MatMul(x_mat, false, ddy_mat, false, static_cast<T>(1.0),
&ddout_mat, static_cast<T>(ddout_flag));
}
}
}
};

} // namespace operators
} // namespace paddle
Loading

0 comments on commit ffc4770

Please sign in to comment.