refactoring matmul_v2 mkldnn hierarchy (PaddlePaddle#37622)
* refactoring matmul hierarchy

* review fix

* review fix

* review_FIX-part2
sfraczek authored and Zjq9409 committed Dec 10, 2021
1 parent 41c5f2d commit fea46a6
Showing 1 changed file with 71 additions and 64 deletions: paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
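For orientation, here is a condensed before/after skeleton of the kernel hierarchy, reconstructed from the hunks below (declarations only, with bodies elided; the call sites of the new ComputeOutputScale helper fall outside the displayed hunks).

Before, the v2 forward kernel inherited from the v1 grad kernel solely to reuse code, and the v2 grad kernel inherited from the v2 forward kernel:

template <typename T>
class MatMulV2MKLDNNKernel
    : public paddle::operators::MatMulGradMKLDNNKernel<T> { /* ... */ };

template <typename T>
class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> { /* ... */ };

After, the shared logic becomes free functions in the file's anonymous namespace, both kernels derive directly from OpKernel, and the v1 grad kernel is held as a plain member used only for the no-broadcast fallback (composition instead of inheritance):

bool IsOutputFused(const ExecutionContext& ctx);
float ComputeOutputScale(const ExecutionContext& ctx);

template <typename T>
void ExecuteMatMulV2(const ExecutionContext& ctx, /* ... */);

template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { /* ... */ };

template <typename T>
class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
  /* ... */
 private:
  paddle::operators::MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
};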
@@ -25,9 +25,9 @@ using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::to_void_cast;
 using Tensor = paddle::framework::Tensor;
-using paddle::framework::vectorize;
-using paddle::framework::make_ddim;
 using paddle::framework::GradVarName;
+using paddle::framework::make_ddim;
+using paddle::framework::vectorize;
 
 template <typename T>
 class MatMulV2MKLDNNHandler
@@ -123,45 +123,58 @@ class MatMulV2MKLDNNHandler
   }
 };
 
-template <typename T>
-class MatMulV2MKLDNNKernel
-    : public paddle::operators::MatMulGradMKLDNNKernel<T> {
- public:
-  void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
+bool IsOutputFused(const ExecutionContext& ctx) {
+  auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
+  auto& fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
+  return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
+}
+
+float ComputeOutputScale(const ExecutionContext& ctx) {
+  float scale_x = ctx.Attr<float>("Scale_x");
+  float scale_y = ctx.Attr<float>("Scale_y");
+  bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
+  float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
+  return scale_out / (scale_x * scale_y);
+}
 
- protected:
-  void ExecuteMatMul(const ExecutionContext& ctx,
+template <typename T>
+void ExecuteMatMulV2(const ExecutionContext& ctx,
                      const MKLDNNDeviceContext& dev_ctx,
                      const dnnl::engine onednn_engine,
                      paddle::platform::Place cpu_place, const Tensor* x,
                      std::vector<int64_t>& x_dims, bool trans_x,
                      const Tensor* y, std::vector<int64_t>& y_dims,
                      bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
-                     int execution_number = 0) const {
-    MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
-                                     trans_x, y_dims, trans_y,
-                                     IsOutputFused(ctx));
+                     int execution_number = 0) {
+  MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
+                                   trans_x, y_dims, trans_y,
+                                   IsOutputFused(ctx));
 
-    const auto src_memory_p = handler.AcquireSrcMemory(x);
-    const auto weights_memory_p = handler.AcquireWeightsMemory(y);
-    const auto dst_memory_p = handler.AcquireDstMemory(out);
+  const auto src_memory_p = handler.AcquireSrcMemory(x);
+  const auto weights_memory_p = handler.AcquireWeightsMemory(y);
+  const auto dst_memory_p = handler.AcquireDstMemory(out);
 
-    auto matmul_p = handler.AcquireForwardPrimitive();
+  auto matmul_p = handler.AcquireForwardPrimitive();
 
-    std::unordered_map<int, memory> matmul_args = {
-        {DNNL_ARG_SRC, *src_memory_p},
-        {DNNL_ARG_WEIGHTS, *weights_memory_p},
-        {DNNL_ARG_DST, *dst_memory_p}};
+  std::unordered_map<int, memory> matmul_args = {
+      {DNNL_ARG_SRC, *src_memory_p},
+      {DNNL_ARG_WEIGHTS, *weights_memory_p},
+      {DNNL_ARG_DST, *dst_memory_p}};
 
-    auto& astream = MKLDNNDeviceContext::tls().get_stream();
-    matmul_p->execute(astream, matmul_args);
-    astream.wait();
+  auto& astream = MKLDNNDeviceContext::tls().get_stream();
+  matmul_p->execute(astream, matmul_args);
+  astream.wait();
 
-    auto format = paddle::platform::MKLDNNFormatForSize(
-        out->dims().size(), dnnl::memory::format_tag::nchw);
-    out->set_layout(paddle::framework::DataLayout::kMKLDNN);
-    out->set_format(format);
-  }
+  auto format = paddle::platform::MKLDNNFormatForSize(
+      out->dims().size(), dnnl::memory::format_tag::nchw);
+  out->set_layout(paddle::framework::DataLayout::kMKLDNN);
+  out->set_format(format);
+}
 
+template <typename T>
+class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
+
  private:
   void CalculateMatrixDims(const ExecutionContext& ctx,
@@ -207,13 +220,6 @@ class MatMulV2MKLDNNKernel
     }
   }
 
-  bool IsOutputFused(const ExecutionContext& ctx) const {
-    auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
-    auto& fused_transpose_Out =
-        ctx.Attr<std::vector<int>>("fused_transpose_Out");
-    return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
-  }
-
   void RunKernel(const ExecutionContext& ctx) const {
     const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     const auto& onednn_engine = dev_ctx.GetEngine();
@@ -237,13 +243,14 @@ class MatMulV2MKLDNNKernel
     CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
                         out);
 
-    ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
-                  trans_x, y, y_bd_dims, trans_y, out, out_dims);
+    ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                       x_bd_dims, trans_x, y, y_bd_dims, trans_y, out,
+                       out_dims);
   }
 };
 
 template <typename T>
-class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
+class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
 
@@ -316,7 +323,7 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
     // if no broadcasting is needed, we can simply use matmul's grad and avoid
     // using reduce_sum
     if (!is_broadcast) {
-      paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
+      matmul_v1_grad_mkldnn_kernel.Compute(ctx);
       return;
     }
 
@@ -342,33 +349,29 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
                         dy_bd_dims);
 
     if (trans_x && trans_y) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
-                          y_dims, true, dout, dout_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
-                          2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
+                         true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
+                         2);
     } else if (trans_x) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
-                          y_dims, false, dout, dout_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                          x_dims, false, dout, dout_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims,
+                         false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
+                         false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
    } else if (trans_y) {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, false, y, y_dims, false, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, true, x, x_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, false, y, y_dims, false, &dx_tmp,
+                         dx_bd_dims, 1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims,
+                         2);
    } else {
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
-                          dout_dims, false, y, y_dims, true, &dx_tmp,
-                          dx_bd_dims, 1);
-      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                          x_dims, true, dout, dout_dims, false, &dy_tmp,
-                          dy_bd_dims, 2);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                         dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims,
+                         1);
+      ExecuteMatMulV2<T>(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims,
+                         true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2);
    }
 
     if (x_dims != dx_bd_dims) {
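A note on the new call sites above: the old ExecuteMatMul was a member function, so T was fixed by the enclosing class; the free function ExecuteMatMulV2 has no parameter that mentions T (paddle::framework::Tensor is type-erased), so callers must now supply the template argument explicitly, hence ExecuteMatMulV2<T>(...). A minimal self-contained illustration, with hypothetical names (ExecuteSketch and the Tensor stand-in are invented for this sketch):

#include <cstdint>
#include <vector>

struct Tensor {};  // stand-in for the type-erased paddle::framework::Tensor

// As with ExecuteMatMulV2, T appears nowhere in the parameter list, so the
// compiler cannot deduce it from the arguments.
template <typename T>
void ExecuteSketch(const Tensor* x, std::vector<int64_t>& dims, bool trans) {}

int main() {
  Tensor t;
  std::vector<int64_t> dims{2, 3};
  // ExecuteSketch(&t, dims, false);      // error: T is not deducible
  ExecuteSketch<float>(&t, dims, false);  // OK: explicit template argument
  return 0;
}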
@@ -389,8 +392,12 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
     dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
     dy->set_format(y->format());
   }
+
+ private:
+  paddle::operators::MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
 };
+
 }  // anonymous namespace
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
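On the new ComputeOutputScale helper added in the first large hunk: it folds the Scale_x, Scale_y, Scale_out and force_fp32_output attributes into the single output-scale factor that a oneDNN matmul primitive consumes. Assuming the usual int8 convention (the raw integer product carries the factor Scale_x * Scale_y, so re-expressing it at Scale_out, or at 1.0 when fp32 output is forced, means multiplying by Scale_out / (Scale_x * Scale_y)), the arithmetic can be restated standalone; ComputeOutputScaleSketch is a hypothetical name and the ctx.Attr plumbing is replaced by plain parameters:

#include <iostream>

// Standalone restatement of ComputeOutputScale from the diff above. The
// reading of the scales as int8 quantization factors is an assumption, not
// something the shown hunks confirm.
float ComputeOutputScaleSketch(float scale_x, float scale_y, float scale_out,
                               bool force_fp32_output) {
  float out = force_fp32_output ? 1.f : scale_out;  // force_fp32_output wins
  return out / (scale_x * scale_y);
}

int main() {
  // Example: both inputs quantized with scale 127, fp32 output requested.
  std::cout << ComputeOutputScaleSketch(127.f, 127.f, 1.f, true) << "\n";
  // prints 1 / (127 * 127), roughly 6.2e-05
  return 0;
}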
