[PHI] Migrate matmul kernel #48162

Merged: 50 commits, Nov 29, 2022.
Changes from 47 commits
Commits
b2c89e6  cleanup unused code (Silv3S, Nov 8, 2022)
1be88bc  unify is_int8 is_bfloat16 (Silv3S, Nov 8, 2022)
4f82616  Simplify matmul_v2 FWD kernel (Silv3S, Nov 8, 2022)
f5375fd  remove RunKernel methods (Silv3S, Nov 8, 2022)
9c927fa  remove import namespace (Silv3S, Nov 8, 2022)
6dd70a1  remove headers (Silv3S, Nov 8, 2022)
f763164  Merge branch 'develop' into mkldnn_cleanup (Silv3S, Nov 8, 2022)
f613c3f  Merge branch 'PaddlePaddle:develop' into mkldnn_cleanup (Silv3S, Nov 9, 2022)
02392d3  clean fluid/phi cross imports (Silv3S, Nov 9, 2022)
a3c0c61  remove fluid axpy_handler (Silv3S, Nov 9, 2022)
fbf1605  delete fluid methods (Silv3S, Nov 9, 2022)
e8dcc47  Merge branch 'mkldnn_cleanup' into axpy_fluid_phi (Silv3S, Nov 9, 2022)
cc3b784  activations (Silv3S, Nov 9, 2022)
0994e24  OneDNNMemDesc (Silv3S, Nov 9, 2022)
6dbfddf  MKLDNNFormatForSize (Silv3S, Nov 9, 2022)
de56962  MatchShapeToLayout (Silv3S, Nov 9, 2022)
1bc636b  MKLDNNMemoryFormat (Silv3S, Nov 9, 2022)
387a16b  MKLDNNFormat (Silv3S, Nov 9, 2022)
7536353  ReorderMKLDNNHandler (Silv3S, Nov 9, 2022)
38427c2  to_void_cast (Silv3S, Nov 9, 2022)
6ff7998  Merge branch 'axpy_fluid_phi' into mkldnn_cleanup (Silv3S, Nov 10, 2022)
206afcb  review suggestions (Silv3S, Nov 10, 2022)
ee87e9c  interpolate (Silv3S, Nov 10, 2022)
6ca8717  Merge branch 'develop' into mkldnn_cleanup (Silv3S, Nov 10, 2022)
0c9ca31  remove fluid depedency (Silv3S, Nov 14, 2022)
d0d94d4  Merge branch 'develop' into mkldnn_cleanup (Silv3S, Nov 14, 2022)
4780111  init (Silv3S, Nov 14, 2022)
17684e0  ExecuteMatMulV2 (Silv3S, Nov 14, 2022)
efb3932  rm fluid kernel (Silv3S, Nov 15, 2022)
2c968e7  Merge branch 'develop' into phi_matmul_grad_kernel (Silv3S, Nov 15, 2022)
2e531cb  matmul_grad (Silv3S, Nov 15, 2022)
650d136  remove mutable_data (Silv3S, Nov 15, 2022)
faea539  Merge branch 'PaddlePaddle:develop' into phi_matmul_grad_kernel (Silv3S, Nov 15, 2022)
8fe8e46  Merge branch 'PaddlePaddle:develop' into phi_matmul_grad_kernel (Silv3S, Nov 16, 2022)
37a7627  mul_grad (Silv3S, Nov 16, 2022)
f24c433  matmul fwd (Silv3S, Nov 16, 2022)
315498f  add extra attr (Silv3S, Nov 18, 2022)
21957ba  Merge branch 'develop' into phi_matmul_kernel (Silv3S, Nov 18, 2022)
c718fdb  Merge branch 'develop' into phi_matmul_kernel (Silv3S, Nov 18, 2022)
9cb4a13  temp disable passes (Silv3S, Nov 18, 2022)
a8e91fc  Merge branch 'develop' into phi_matmul_kernel (Silv3S, Nov 21, 2022)
b49dddb  re-enable passes (Silv3S, Nov 21, 2022)
dae9341  workaround for matmul+act (Silv3S, Nov 22, 2022)
a325beb  fix for matmul+eltwise_add (Silv3S, Nov 22, 2022)
cc8f393  Merge branch 'PaddlePaddle:develop' into phi_matmul_kernel (Silv3S, Nov 22, 2022)
fe384a1  fix typo (Silv3S, Nov 23, 2022)
6328e85  merge bugfix #48364 (Silv3S, Nov 25, 2022)
65cf6df  Merge branch 'develop' into phi_matmul_kernel (Silv3S, Nov 28, 2022)
1c1c802  remove merge conflict (Silv3S, Nov 28, 2022)
f5e3bfe  Merge branch 'PaddlePaddle:develop' into phi_matmul_kernel (Silv3S, Nov 28, 2022)
24 changes: 24 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -3230,6 +3230,30 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
VLOG(4) << "Done attributes";

// Clear all old attrs before adding new attrs,
// because sometimes old attrs may be misused.
#if defined(PADDLE_WITH_MKLDNN)
if (phi::OneDNNContext::classof(dev_ctx)) {
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
one_dnn_ctx->ClearDnnAttr();
}
#endif

// Note(YuanRisheng): We can't enable the code below yet, because some
// unittests still run the OLD dygraph, which doesn't support ExtraAttr.
// As a workaround we rely on dev_ctx being a global object: ExtraAttrs
// stored during static-graph execution stay available when a unittest
// later runs the OLD dygraph. This code can be enabled once the OLD
// dygraph is no longer used.
/*
#if defined(PADDLE_WITH_CUDA)
if(phi::GPUContext::classof(dev_ctx)) {
phi::GPUContext* gpu_dnn_ctx = static_cast<phi::GPUContext*>(dev_ctx);
gpu_dnn_ctx->ClearDnnAttr();
}
#endif
*/

// For compatibility with Ops that have extra attrs for a specific backend
#if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
auto& runtime_attrs = RuntimeAttrs();
18 changes: 5 additions & 13 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -383,7 +383,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
}

template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
@@ -699,21 +699,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
REGISTER_OP_KERNEL(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
MatMulMKLDNNKernel<float>,
MatMulMKLDNNKernel<paddle::platform::bfloat16>,
MatMulMKLDNNKernel<int8_t>,
MatMulMKLDNNKernel<uint8_t>);

REGISTER_OP_KERNEL(matmul_grad,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulGradMKLDNNKernel<float>,
MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
4 changes: 3 additions & 1 deletion paddle/fluid/operators/ops_extra_info.h
@@ -98,6 +98,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"fuse_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_beta", ExtraAttrProperty::ONEDNN},
{"fuse_relu", ExtraAttrProperty::ONEDNN},
{"fused_output_scale", ExtraAttrProperty::ONEDNN},
{"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
{"fuse_with_relu", ExtraAttrProperty::ONEDNN},
{"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
@@ -221,7 +222,8 @@ class ExtraInfoUtils {
std::unordered_map<std::string, std::vector<std::string>>
g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
{"conv2d_transpose", {"Bias"}},
{"conv2d_grad", {"Bias"}}};
{"conv2d_grad", {"Bias"}},
{"matmul_v2", {"ResidualData"}}};
std::vector<std::string> empty_extra_input_names_;
};

4 changes: 4 additions & 0 deletions paddle/phi/backends/gpu/gpu_context.cc
@@ -740,6 +740,8 @@ struct GPUContext::Impl {
dnn_attrs_[attr_name] = attr;
}

void ClearDnnAttr() { dnn_attrs_.clear(); }

// use one flag for all handles?
// they should be accessed consistently
bool owned_{false};
@@ -1042,4 +1044,6 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
return impl_->SetDnnAttr(attr_name, std::move(attr));
}

void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }

} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/backends/gpu/gpu_context.h
@@ -172,6 +172,7 @@ class PADDLE_API GPUContext : public DeviceContext,
bool HasDnnAttr(const std::string& attr_name) const;
const Attribute& GetDnnAttr(const std::string& attr_name) const;
void SetDnnAttr(const std::string& attr_name, Attribute attr);
void ClearDnnAttr();

static const char* name() { return "GPUContext"; }

4 changes: 4 additions & 0 deletions paddle/phi/backends/onednn/onednn_context.cc
@@ -301,6 +301,8 @@ struct OneDNNContext::Impl {
dnn_attrs_[attr_name] = attr;
}

void ClearDnnAttr() { dnn_attrs_.clear(); }

bool HasDnnInput(const std::string& input_name) const {
return dnn_inputs_.count(input_name) != 0UL;
}
@@ -429,6 +431,8 @@ bool OneDNNContext::HasDnnInput(const std::string& input_name) const {
return impl_->HasDnnInput(input_name);
}

void OneDNNContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }

const DenseTensor* OneDNNContext::GetDnnInput(
const std::string& input_name) const {
return impl_->GetDnnInput(input_name);
2 changes: 2 additions & 0 deletions paddle/phi/backends/onednn/onednn_context.h
@@ -146,6 +146,8 @@ class OneDNNContext : public CPUContext {
const DenseTensor* GetDnnInput(const std::string& input_name) const;
void SetDnnInput(const std::string& input_name, const DenseTensor* input);

void ClearDnnAttr();

void SetInputsName(const TensorNameMap& inputs_name);

void SetOutputsName(const TensorNameMap& outputs_name);
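For orientation, a minimal sketch of the DNN-attr lifecycle these declarations support. The function name and the 0.5f value below are illustrative only (not taken from this PR); the attr is read the same way the matmul handler reads fused_output_scale.

#include "paddle/phi/backends/onednn/onednn_reuse.h"  // provides OneDNNContext and PADDLE_GET_CONST

// Illustrative only: how a OneDNNContext carries extra attrs from the
// framework to a oneDNN kernel. BuildPhiKernelContext now clears leftover
// attrs first, fusion passes set new ones, and the kernel reads them.
void AttrLifecycleSketch(phi::OneDNNContext* dev_ctx) {
  dev_ctx->ClearDnnAttr();  // drop attrs left over from the previous op
  dev_ctx->SetDnnAttr("fused_output_scale", 0.5f);  // hypothetical value
  if (dev_ctx->HasDnnAttr("fused_output_scale")) {
    const float scale =
        PADDLE_GET_CONST(float, dev_ctx->GetDnnAttr("fused_output_scale"));
    (void)scale;  // a kernel would map this onto a oneDNN post-op
  }
}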
21 changes: 14 additions & 7 deletions paddle/phi/backends/onednn/onednn_reuse.h
@@ -1844,9 +1844,11 @@ class MatmulOneDNNHandler
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;

if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
if (residual_data) {
auto residual_data_tz = vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
OneDNNGetDataType<OT>(),
@@ -1863,9 +1865,11 @@

AppendActivation(dev_ctx, post_operations);

if (dev_ctx.HasDnnAttr("fused_output_scale")) {
float scale_alpha =
PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
const float scale_alpha =
dev_ctx.HasDnnAttr("fused_output_scale")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
: 1.0f;
if (scale_alpha != 1.0f) {
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
@@ -1984,8 +1988,11 @@ void ExecuteMatmul(const OneDNNContext& dev_ctx,
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;

if (residual_data) {
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
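The fused_output_scale attribute above is folded into the matmul primitive as an extra post-op rather than run as a separate multiply. A standalone sketch of that idea in plain oneDNN (not Paddle code; the helper name and scale value are illustrative), using the same append_eltwise signature the handler uses:

#include "dnnl.hpp"

// Illustrative only: an eltwise_linear post-op with alpha = scale and beta = 0
// computes out = scale * out + 0, i.e. it rescales the primitive's output.
dnnl::primitive_attr MakeScaledAttrs(float fused_output_scale) {
  dnnl::post_ops post_operations;
  if (fused_output_scale != 1.0f) {
    post_operations.append_eltwise(
        1.0f, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
  }
  dnnl::primitive_attr attrs;
  attrs.set_post_ops(post_operations);
  return attrs;
}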
164 changes: 164 additions & 0 deletions paddle/phi/kernels/onednn/matmul_kernel.cc
@@ -0,0 +1,164 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/matmul_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

DDim GetDimsForInput(const OneDNNContext &dev_ctx,
DDim input_dims,
std::string input_name) {
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}

void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
DenseTensor *out,
const bool is_output_fused) {
if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
(*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
}
}

if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i,
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(make_ddim((out_dims)));
}
}

template <typename T, typename Context>
void MatmulKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
bool transpose_x,
bool transpose_y,
DenseTensor *out) {
if (dev_ctx.HasDnnAttr("head_number")) {
const auto head_number =
PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
PADDLE_ENFORCE_EQ(
head_number,
1,
errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
head_number));
}

constexpr bool is_int8 = funcs::is_int8<T>();
constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
const bool force_fp32_output =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;

bool fuse_relu = false;
if (dev_ctx.HasDnnAttr("fuse_activation")) {
auto act_type =
PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
if (act_type == "relu" || act_type == "relu6") {
fuse_relu = true;
}
}

auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));

int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);

std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);

CalculateMatrixDims(x_dims,
y_dims,
&x_bd_dims,
&y_bd_dims,
out,
funcs::IsOutputFused(dev_ctx));

if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
funcs::ExecuteMatmul<T, float>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (is_bfloat16) {
funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (fuse_relu) {
funcs::ExecuteMatmul<T, uint8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else {
funcs::ExecuteMatmul<T, int8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
}
}

} // namespace phi

PD_REGISTER_KERNEL(matmul,
OneDNN,
ONEDNN,
phi::MatmulKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
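To make the shape handling in CalculateMatrixDims concrete, here is a minimal standalone sketch (not part of this PR; PadToRank is a hypothetical helper) of the rank-padding it performs for inputs of rank 2 or higher: shapes are right-aligned to a common rank of at least 3, and missing leading batch dims are filled with 1 so they broadcast.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Minimal sketch of the rank-padding done before ExecuteMatmul. Assumes
// dims.size() <= ndims; rank-1 inputs get special handling in
// CalculateMatrixDims above and are not covered here.
std::vector<int64_t> PadToRank(const std::vector<int64_t>& dims, size_t ndims) {
  std::vector<int64_t> padded(ndims, 1);
  std::copy(dims.rbegin(), dims.rend(), padded.rbegin());  // right-align
  return padded;
}

int main() {
  const auto x_bd = PadToRank({2, 3, 4}, 3);  // -> [2, 3, 4]
  const auto y_bd = PadToRank({4, 5}, 3);     // -> [1, 4, 5]
  for (auto d : x_bd) std::cout << d << ' ';
  std::cout << '\n';
  for (auto d : y_bd) std::cout << d << ' ';
  std::cout << '\n';
}

When both inputs have rank greater than 2, the batch dims are then broadcast pairwise (each pair must match or contain a 1) and out is resized to the broadcast shape; rank-1 inputs instead fill only the last position of x_bd_dims or the second-to-last of y_bd_dims, matching matrix-vector semantics.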