diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 8ac81596a36d3..d39eac0759cdb 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel { } std::vector new_dims; - if (ndims_x >= ndims_y) { + if (ndims_x > ndims_y) { new_dims.assign(dims_x.begin(), dims_x.end() - 2); - } else { + } else if (ndims_x < ndims_y) { new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } else { + new_dims.reserve(ndims_x); + for (size_t i = 0; i < ndims_x - 2; ++i) { + new_dims.push_back(std::max(dims_x[i], dims_y[i])); + } } if (!x_broadcasted) { new_dims.push_back(M); @@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), - ctx.GetPlace()); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h similarity index 97% rename from paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h index 2b3496359b0c6..9e90fe805d27e 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h @@ -582,11 +582,12 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel { : FoldFirstAndLastDims(dev_ctx, y); } + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + MatMulMKLDNNHandler handler( dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined, - trans_y, out, ctx.Attr("alpha"), - ctx.InputName(framework::GradVarName("Out")) + - std::to_string(execution_number)); + trans_y, out, alpha, ctx.InputName(framework::GradVarName("Out")) + + std::to_string(execution_number)); const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); @@ -620,10 +621,15 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - bool transpose_x = ctx.Attr("transpose_X"); - bool transpose_y = ctx.Attr("transpose_Y"); + bool transpose_x = ctx.HasAttr("transpose_X") + ? ctx.Attr("transpose_X") + : ctx.Attr("trans_x"); + bool transpose_y = ctx.HasAttr("transpose_Y") + ? ctx.Attr("transpose_Y") + : ctx.Attr("trans_y"); ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; if (dx) { dx_dims = dx->dims(); @@ -665,11 +671,13 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel { if (dx) { if (dx_dims != x.dims()) { dx->Resize(dx_dims); + dx->set_format(x.format()); } } if (dy) { if (dy_dims != y.dims()) { dy->Resize(dy_dims); + dy->set_format(y.format()); } } } diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 50afd417170e0..73397786f57b9 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" namespace paddle { namespace operators { @@ -35,14 +32,17 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT { public: MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, - std::vector& x_dims, bool trans_x, - std::vector& y_dims, bool trans_y, + const std::vector& x_org_dims, bool trans_x, + const std::vector& y_org_dims, bool trans_y, const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, - platform::CreateKey(dev_ctx, x_dims, uniq_name)) { + platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) { if (!this->isCached()) { // M X K * K X N + std::vector x_dims(x_org_dims); + std::vector y_dims(y_org_dims); + const int MB_idx = x_dims.size() - 3; const int H_idx = x_dims.size() - 2; const int W_idx = x_dims.size() - 1; @@ -104,10 +104,43 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT { }; template -class MatMulV2MKLDNNKernel : public framework::OpKernel { +class MatMulV2MKLDNNKernel : public MatMulGradMKLDNNKernel { public: void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } + protected: + void ExecuteMatMul(const ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine onednn_engine, + platform::Place cpu_place, const Tensor* x, + std::vector& x_dims, bool trans_x, + const Tensor* y, std::vector& y_dims, + bool trans_y, Tensor* out, std::vector& out_dims, + int execution_number = 0) const { + MatMulV2MKLDNNHandler handler( + dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims, + trans_y, ctx.InputName("X") + std::to_string(execution_number)); + + const auto src_memory_p = handler.AcquireSrcMemory(x); + const auto weights_memory_p = handler.AcquireWeightsMemory(y); + const auto dst_memory_p = handler.AcquireDstMemory(out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = MKLDNNDeviceContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format( + GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); + } + private: void CalculateMatrixDims(const ExecutionContext& ctx, const std::vector& x_dims, @@ -117,6 +150,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { std::vector& out_dims, Tensor* out) const { if (x_dims.size() == 1) { x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; + } else if (x_dims.size() == 2) { + x_bd_dims[2] = x_dims[1]; + x_bd_dims[1] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { x_bd_dims[i] = x_dims[i]; @@ -124,6 +160,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { } if (y_dims.size() == 1) { y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; + } else if (y_dims.size() == 2) { + y_bd_dims[2] = y_dims[1]; + y_bd_dims[1] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { y_bd_dims[i] = y_dims[i]; @@ -168,30 +207,160 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, out); - MatMulV2MKLDNNHandler handler(dev_ctx, onednn_engine, ctx.GetPlace(), - x_bd_dims, trans_x, y_bd_dims, trans_y, - ctx.InputName("X")); + ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims, + trans_x, y, y_bd_dims, trans_y, out, out_dims); + } +}; - const auto src_memory_p = handler.AcquireSrcMemory(x); - const auto weights_memory_p = handler.AcquireWeightsMemory(y); - const auto dst_memory_p = handler.AcquireDstMemory(out); +template +class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { + public: + void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } - auto matmul_p = handler.AcquireForwardPrimitive(); + private: + void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp, + Tensor* dy_tmp, + const std::vector& dx_dims, + const std::vector& dy_dims, + std::vector& dx_bd_dims, + std::vector& dy_bd_dims) const { + for (size_t i = 0; i < dx_dims.size() - 2; ++i) { + if (dx_dims[i] != dy_dims[i]) { + if (dx_dims[i] == 1) { + dx_bd_dims[i] = dy_dims[i]; + } else { + dy_bd_dims[i] = dx_dims[i]; + } + } + } - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; + dx_tmp->Resize(framework::make_ddim(dx_bd_dims)); + dx_tmp->mutable_data(ctx.GetPlace()); + dy_tmp->Resize(framework::make_ddim(dy_bd_dims)); + dy_tmp->mutable_data(ctx.GetPlace()); + } - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); + void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine onednn_engine, + const Tensor* dx_tmp, Tensor* dx, + std::vector dx_dims) const { + platform::ReductionMKLDNNHandler handler( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, + ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims); + + auto src_memory_p = handler.AcquireSrcMemory(dx_tmp); + auto dst_memory_p = handler.AcquireDstMemory(dx); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); astream.wait(); + } - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format( - GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); + void RunKernel(const ExecutionContext& ctx) const { + const auto& dev_ctx = ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + + auto x_dims = framework::vectorize(x->dims()); + auto y_dims = framework::vectorize(y->dims()); + + bool is_broadcast = true; + if (x_dims.size() <= 2 || y_dims.size() <= 2) { + is_broadcast = false; + } else if (x_dims.size() != y_dims.size()) { + is_broadcast = true; + } else { + is_broadcast = + !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2, + y_dims.cbegin()); + } + + // if no broadcasting is needed, we can simply use matmul's grad and avoid + // using reduce_sum + if (!is_broadcast) { + MatMulGradMKLDNNKernel::Compute(ctx); + return; + } + + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); + auto dout_dims = framework::vectorize(dout->dims()); + + int ndims = std::max(x->dims().size(), y->dims().size()); + ndims = std::max(ndims, 3); + + // in broadcasting scenario new memory is required because + // reduce sum must be calculated upon broadcasted dims + Tensor dx_tmp, dy_tmp; + + std::vector dx_bd_dims(x_dims); + std::vector dy_bd_dims(y_dims); + + CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims, + dy_bd_dims); + + if (trans_x && trans_y) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, + y_dims, true, dout, dout_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims, + 2); + } else if (trans_x) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, + y_dims, false, dout, dout_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, + x_dims, false, dout, dout_dims, false, &dy_tmp, + dy_bd_dims, 2); + } else if (trans_y) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, false, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, false, &dy_tmp, + dy_bd_dims, 2); + } else { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, + x_dims, true, dout, dout_dims, false, &dy_tmp, + dy_bd_dims, 2); + } + + if (x_dims != dx_bd_dims) { + ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, + x_dims); + } else { + *dx = std::move(dx_tmp); + } + if (y_dims != dy_bd_dims) { + ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, + y_dims); + } else { + *dy = std::move(dy_tmp); + } + + dx->set_layout(framework::DataLayout::kMKLDNN); + dx->set_format(x->format()); + dy->set_layout(framework::DataLayout::kMKLDNN); + dy->set_format(y->format()); } }; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; @@ -200,6 +369,6 @@ REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, ops::MatMulV2MKLDNNKernel, ops::MatMulV2MKLDNNKernel); -// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace, -// ops::MatMulV2GradMKLDNNKernel, -// ops::MatMulV2GradMKLDNNKernel); +REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::MatMulV2GradMKLDNNKernel, + ops::MatMulV2GradMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index ea06e2c447233..3b9d817522561 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest +from functools import reduce import numpy as np from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 @@ -23,14 +24,12 @@ import paddle.fluid as fluid import paddle.fluid.framework as framework -paddle.enable_static() - -def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): +def reference_matmul(X, Y, trans_x=False, trans_y=False): """Reference forward implementation using np.matmul.""" # np.matmul does not support the transpose flags, so we manually # transpose X and Y appropriately. - if transpose_X: + if trans_x: if X.ndim == 1: X = X.reshape((X.size, )) elif X.ndim == 2: @@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): dim = [i for i in range(len(X.shape))] dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] X = np.transpose(X, tuple(dim)) - if transpose_Y: + if trans_y: if Y.ndim == 1: Y = Y.reshape((Y.size, )) else: @@ -144,8 +143,8 @@ def config(self): class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (1, 1, 12, 4) - self.y_shape = (1, 2, 4, 12) + self.x_shape = (2, 1, 12, 9) + self.y_shape = (1, 3, 9, 12) self.trans_x = False self.trans_y = False @@ -170,8 +169,8 @@ def config(self): class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (2, 2, 5, 4) - self.y_shape = (2, 2, 5, 3) + self.x_shape = (2, 2, 7, 4) + self.y_shape = (2, 2, 7, 5) self.trans_x = True self.trans_y = False @@ -179,7 +178,7 @@ def config(self): class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp( TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (3, 1, 6, 5) + self.x_shape = (3, 1, 6, 7) self.y_shape = (1, 2, 6, 9) self.trans_x = True self.trans_y = False @@ -203,8 +202,8 @@ def config(self): class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (2, 1, 40) - self.y_shape = (40) + self.x_shape = (2, 1, 100) + self.y_shape = (100) self.trans_x = False self.trans_y = False @@ -245,6 +244,8 @@ def set_inputs(self, x, y): 'X': convert_float_to_uint16(x), 'Y': convert_float_to_uint16(y) } + self.x_fp32 = x + self.y_fp32 = y def set_dtype_attr(self): self.attrs['mkldnn_data_type'] = "bfloat16" @@ -253,7 +254,99 @@ def test_check_output(self): self.check_output_with_place(core.CPUPlace()) def test_check_grad(self): - pass + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + user_defined_grads=[self.dx, self.dy], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) + + def matmul_grad(self, x, transpose_x, y, transpose_y): + x = np.transpose( + x, self.shape_transpose_axes[x.ndim]) if transpose_x else x + y = np.transpose( + y, self.shape_transpose_axes[y.ndim]) if transpose_y else y + + return np.matmul(x, y) + + def calculate_grads(self): + self.shape_transpose_axes = { + 2: [1, 0], + 3: [0, 2, 1], + 4: [0, 1, 3, 2], + 5: [0, 1, 2, 4, 3] + } + + # expand vector so it will be a valid matrix for multiplication + if self.x_fp32.ndim == 1: + self.x_fp32 = np.expand_dims(self.x_fp32, axis=0) + if self.y_fp32.ndim == 1: + self.y_fp32 = np.expand_dims(self.y_fp32, axis=1) + + x_transpose_axes = self.shape_transpose_axes[self.x_fp32.ndim] + y_transpose_axes = self.shape_transpose_axes[self.y_fp32.ndim] + + x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[ + 'trans_x'] is True else self.x_fp32 + y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[ + 'trans_y'] is True else self.y_fp32 + + dout = np.matmul(x, y) + + x_shape = x.shape + y_shape = y.shape + + if x.ndim <= 2 or y.ndim <= 2: + is_broadcast = False + elif x.ndim != y.ndim: + is_broadcast = True + else: + is_broadcast = x.shape[0:-2] != y.shape[0:-2] + + if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True: + self.dx = self.matmul_grad(self.y_fp32, True, dout, True) + self.dy = self.matmul_grad(dout, True, self.x_fp32, True) + elif self.attrs['trans_x'] is True and self.attrs[ + 'trans_y'] is False: + self.dx = self.matmul_grad(self.y_fp32, False, dout, True) + self.dy = self.matmul_grad(self.x_fp32, False, dout, False) + elif self.attrs['trans_x'] is False and self.attrs[ + 'trans_y'] is True: + self.dx = self.matmul_grad(dout, False, self.y_fp32, False) + self.dy = self.matmul_grad(dout, True, self.x_fp32, False) + else: + self.dx = self.matmul_grad(dout, False, self.y_fp32, True) + self.dy = self.matmul_grad(self.x_fp32, True, dout, False) + + if is_broadcast: + x_reduce_axis = [] + y_reduce_axis = [] + for index, ( + first, second + ) in enumerate(zip(x_shape[0:-2], self.dx.shape[0:-2])): + if first != second: + x_reduce_axis.append(index) + + for index, ( + first, second + ) in enumerate(zip(y_shape[0:-2], self.dy.shape[0:-2])): + if first != second: + y_reduce_axis.append(index) + + if x_reduce_axis: + self.dx = self.dx.sum(axis=tuple(x_reduce_axis), + keepdims=True) + if y_reduce_axis: + self.dy = self.dy.sum(axis=tuple(y_reduce_axis), + keepdims=True) + + # after multiplying with vector one dimension is deleted from tensor + if len(x_shape) == 2 and x_shape[0] == 1: + dout = dout.sum(axis=-2) + if len(y_shape) == 2 and y_shape[1] == 1: + dout = dout.sum(axis=-1) + + self.dout = dout cls_name = "{0}_{1}".format(parent.__name__, "BF16") TestMatMulV2Bf16OneDNNOp.__name__ = cls_name