Merge branch 'matmul_v2_grad' into squeeze2_op
jakpiase committed Jul 19, 2021
2 parents 6e3f767 + 2104d0d commit 06fcf67
Showing 4 changed files with 333 additions and 51 deletions.
24 changes: 18 additions & 6 deletions paddle/fluid/operators/matmul_v2_op.cc
@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel {
}

std::vector<int64_t> new_dims;
- if (ndims_x >= ndims_y) {
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
- } else {
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
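
As an aside, the batch-dimension rule in the hunk above is easy to exercise in isolation. A minimal standalone sketch, where BroadcastBatchDims is a hypothetical helper written for this note (not part of the patch): batch dims come from the higher-rank input, and for equal ranks each dim is the elementwise max so that 1s broadcast.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Compute the leading (batch) dims of the matmul output; ranks of at least 2
// and compatible batch dims (equal or 1) are assumed.
std::vector<int64_t> BroadcastBatchDims(const std::vector<int64_t>& dims_x,
                                        const std::vector<int64_t>& dims_y) {
  if (dims_x.size() > dims_y.size())
    return std::vector<int64_t>(dims_x.begin(), dims_x.end() - 2);
  if (dims_x.size() < dims_y.size())
    return std::vector<int64_t>(dims_y.begin(), dims_y.end() - 2);
  std::vector<int64_t> batch;
  for (size_t i = 0; i + 2 < dims_x.size(); ++i)
    batch.push_back(std::max(dims_x[i], dims_y[i]));
  return batch;
}

int main() {
  // [2, 1, 4, 3] matmul [1, 5, 3, 6] -> batch dims [2, 5], output [2, 5, 4, 6]
  for (int64_t d : BroadcastBatchDims({2, 1, 4, 3}, {1, 5, 3, 6}))
    std::cout << d << ' ';  // prints: 2 5
  std::cout << '\n';
}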
@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
- auto out_grad_name = framework::GradVarName("Out");
- return framework::OpKernelType(
- OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
- ctx.GetPlace());
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
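// prefer the oneDNN grad kernel when the context allows it; otherwise fall
// through to the default kernel type on the current place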
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
@@ -582,11 +582,12 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
: FoldFirstAndLastDims<T>(dev_ctx, y);
}

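// matmul defines an "alpha" scaling attribute but matmul_v2 does not, so
// default to 1.0f when the attribute is absent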
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

MatMulMKLDNNHandler<T> handler(
dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined,
trans_y, out, ctx.Attr<float>("alpha"),
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
trans_y, out, alpha, ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));

const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
@@ -620,10 +621,15 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

- bool transpose_x = ctx.Attr<bool>("transpose_X");
- bool transpose_y = ctx.Attr<bool>("transpose_Y");
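// matmul names these attributes "transpose_X"/"transpose_Y" while matmul_v2
// names them "trans_x"/"trans_y"; probe with HasAttr so either op can reach
// this grad kernel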
bool transpose_x = ctx.HasAttr("transpose_X")
? ctx.Attr<bool>("transpose_X")
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y")
? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");

ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);

framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
@@ -665,11 +671,13 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_format(x.format());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_format(y.format());
}
}
}
223 changes: 196 additions & 27 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"

namespace paddle {
namespace operators {
@@ -35,14 +32,17 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
- std::vector<int64_t>& x_dims, bool trans_x,
- std::vector<int64_t>& y_dims, bool trans_y,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
- platform::CreateKey(dev_ctx, x_dims, uniq_name)) {
platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
@@ -104,10 +104,43 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
};

template <typename T>
- class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
class MatMulV2MKLDNNKernel : public MatMulGradMKLDNNKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }

protected:
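// shared helper: builds (or fetches from cache) the oneDNN matmul primitive
// for a single x*y product and executes it; the grad kernel below calls it
// once per gradient it produces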
void ExecuteMatMul(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
platform::Place cpu_place, const Tensor* x,
std::vector<int64_t>& x_dims, bool trans_x,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));

const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);

auto matmul_p = handler.AcquireForwardPrimitive();

std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();

out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
}

private:
void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims,
@@ -117,13 +150,19 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t>& out_dims, Tensor* out) const {
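// promote 1-D and 2-D inputs to the batched 3-D form the handler expects:
// real dims are copied in from the right, leaving the leading entries at
// their defaults (assumed to be pre-filled with 1s by the caller)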
if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[2] = x_dims[1];
x_bd_dims[1] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[2] = y_dims[1];
y_bd_dims[1] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i];
@@ -168,30 +207,160 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out);

- MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(),
- x_bd_dims, trans_x, y_bd_dims, trans_y,
- ctx.InputName("X"));
ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
trans_x, y, y_bd_dims, trans_y, out, out_dims);
}
};

- const auto src_memory_p = handler.AcquireSrcMemory(x);
- const auto weights_memory_p = handler.AcquireWeightsMemory(y);
- const auto dst_memory_p = handler.AcquireDstMemory(out);
template <typename T>
class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }

- auto matmul_p = handler.AcquireForwardPrimitive();
private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const {
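// wherever x and y disagree on a batch dim, the side holding a 1 is
// expanded to the broadcast size; the surplus is reduce-summed away later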
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i];
} else {
dy_bd_dims[i] = dx_dims[i];
}
}
}

- std::unordered_map<int, memory> matmul_args = {
- {DNNL_ARG_SRC, *src_memory_p},
- {DNNL_ARG_WEIGHTS, *weights_memory_p},
- {DNNL_ARG_DST, *dst_memory_p}};
dx_tmp->Resize(framework::make_ddim(dx_bd_dims));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(framework::make_ddim(dy_bd_dims));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}

- auto& astream = MKLDNNDeviceContext::tls().get_stream();
- matmul_p->execute(astream, matmul_args);
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
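// collapse the broadcast batch dims with oneDNN's reduction_sum primitive
// so the gradient regains the original input shape (dx_dims)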
platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);

auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);

std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reduction_p = handler.AcquireForwardPrimitive();

reduction_p->execute(astream, reduction_args);
astream.wait();
}

- out->set_layout(framework::DataLayout::kMKLDNN);
- out->set_format(
- GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();

auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");

auto x_dims = framework::vectorize(x->dims());
auto y_dims = framework::vectorize(y->dims());

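// the plain matmul grad path suffices unless both inputs carry batch dims
// and those dims differ in rank or size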
bool is_broadcast = true;
if (x_dims.size() <= 2 || y_dims.size() <= 2) {
is_broadcast = false;
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast =
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}

// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
MatMulGradMKLDNNKernel<T>::Compute(ctx);
return;
}

auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = framework::vectorize(dout->dims());

int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);

// in a broadcasting scenario new memory is required because
// reduce_sum must be calculated over the broadcast dims
Tensor dx_tmp, dy_tmp;

std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);

CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
dy_bd_dims);

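// with Out = op(X) * op(Y), the chain rule gives, per transpose case:
//   trans_x && trans_y:  dX = Y^T * dOut^T,  dY = dOut^T * X^T
//   trans_x only:        dX = Y * dOut^T,    dY = X * dOut
//   trans_y only:        dX = dOut * Y,      dY = dOut^T * X
//   neither:             dX = dOut * Y^T,    dY = X^T * dOut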
if (trans_x && trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, true, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
2);
} else if (trans_x) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, false, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, false, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else if (trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, true, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
}

if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims);
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims);
} else {
*dy = std::move(dy_tmp);
}

dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(x->format());
dy->set_layout(framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
}
};

} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
@@ -200,6 +369,6 @@ REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulV2MKLDNNKernel<float>,
ops::MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);

- // REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace,
- // ops::MatMulV2GradMKLDNNKernel<float>,
- // ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulV2GradMKLDNNKernel<float>,
ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);