From e48c7cc31975bd01fd2ab94bc23bd884e19415bf Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Sat, 23 Jul 2022 01:34:10 +0800
Subject: [PATCH 01/13] Add kernel declarations

---
 .../phi/kernels/spectral_norm_grad_kernel.h   | 29 +++++++++++++++++++
 paddle/phi/kernels/spectral_norm_kernel.h     | 28 ++++++++++++++++++
 2 files changed, 57 insertions(+)
 create mode 100644 paddle/phi/kernels/spectral_norm_grad_kernel.h
 create mode 100644 paddle/phi/kernels/spectral_norm_kernel.h

diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h
new file mode 100644
index 0000000000000..047b22d02ac6c
--- /dev/null
+++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h
@@ -0,0 +1,29 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SpectrumNormGradKernel(const Context& dev_ctx
+                            const DenseTensor& weight,
+                            const DenseTensor& u,
+                            const DenseTensor& v,
+                            const DenseTensor& out_grad,
+                            int dim,
+                            int power_iters,
+                            float eps,
+                            DenseTensor* weight_grad);
+
+} // namespace phi
\ No newline at end of file
diff --git a/paddle/phi/kernels/spectral_norm_kernel.h b/paddle/phi/kernels/spectral_norm_kernel.h
new file mode 100644
index 0000000000000..5c638027fbf2d
--- /dev/null
+++ b/paddle/phi/kernels/spectral_norm_kernel.h
@@ -0,0 +1,28 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SpectrumNormKernel(const Context& dev_ctx
+                        const DenseTensor& weight,
+                        const DenseTensor& u,
+                        const DenseTensor& v,
+                        int dim,
+                        int power_iters,
+                        float eps,
+                        DenseTensor* out);
+
+} // namespace phi
\ No newline at end of file
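For orientation before the implementation patches: spectral normalization rescales a weight matrix W (viewed as h x w) by an estimate of its largest singular value sigma, refined over `power_iters` rounds of power iteration on the persistent vectors u (length h) and v (length w), with sigma = u^T W v and Out = W / sigma. A minimal standalone sketch of that estimate, with plain loops standing in for the Eigen/BLAS calls the kernels below use (the function name and loop style here are illustrative assumptions, not part of the patch):

    #include <cmath>
    #include <vector>

    // Power-iteration estimate of the largest singular value of a row-major
    // h x w matrix W; u and v are refined in place, as in the kernels.
    float EstimateSigma(const std::vector<float>& W, std::vector<float>* u,
                        std::vector<float>* v, int h, int w, int power_iters,
                        float eps) {
      for (int it = 0; it < power_iters; ++it) {
        // v = W^T u / (||W^T u|| + eps)
        for (int j = 0; j < w; ++j) {
          (*v)[j] = 0.f;
          for (int i = 0; i < h; ++i) (*v)[j] += W[i * w + j] * (*u)[i];
        }
        float vn = 0.f;
        for (int j = 0; j < w; ++j) vn += (*v)[j] * (*v)[j];
        vn = std::sqrt(vn);
        for (int j = 0; j < w; ++j) (*v)[j] /= (vn + eps);
        // u = W v / (||W v|| + eps)
        for (int i = 0; i < h; ++i) {
          (*u)[i] = 0.f;
          for (int j = 0; j < w; ++j) (*u)[i] += W[i * w + j] * (*v)[j];
        }
        float un = 0.f;
        for (int i = 0; i < h; ++i) un += (*u)[i] * (*u)[i];
        un = std::sqrt(un);
        for (int i = 0; i < h; ++i) (*u)[i] /= (un + eps);
      }
      // sigma = u^T W v; the forward kernel then emits W / sigma
      float sigma = 0.f;
      for (int i = 0; i < h; ++i)
        for (int j = 0; j < w; ++j) sigma += (*u)[i] * W[i * w + j] * (*v)[j];
      return sigma;
    }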
From f83a3cde315e4bdd934b576eb946aef2521f3873 Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Sat, 23 Jul 2022 01:41:22 +0800
Subject: [PATCH 02/13] Copy kernel implementation code

---
 .../kernels/cpu/spectral_norm_grad_kernel.cc  |  13 ++
 .../phi/kernels/cpu/spectral_norm_kernel.cc   |  13 ++
 paddle/phi/kernels/funcs/spectral_norm.h      | 109 +++++++++++++++
 .../kernels/gpu/spectral_norm_grad_kernel.cu  |  13 ++
 .../phi/kernels/gpu/spectral_norm_kernel.cu   |  13 ++
 .../impl/spectral_norm_grad_kernel_impl.h     | 132 ++++++++++++++++++
 .../kernels/impl/spectral_norm_kernel_impl.h  | 104 ++++++++++++++
 7 files changed, 397 insertions(+)
 create mode 100644 paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/spectral_norm_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/spectral_norm.h
 create mode 100644 paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/spectral_norm_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/spectral_norm_kernel_impl.h

diff --git a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
new file mode 100644
index 0000000000000..564892f1679df
--- /dev/null
+++ b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
\ No newline at end of file
diff --git a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc
new file mode 100644
index 0000000000000..564892f1679df
--- /dev/null
+++ b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc
@@ -0,0 +1,13 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
\ No newline at end of file
diff --git a/paddle/phi/kernels/funcs/spectral_norm.h b/paddle/phi/kernels/funcs/spectral_norm.h
new file mode 100644
index 0000000000000..75c460b5382aa
--- /dev/null
+++ b/paddle/phi/kernels/funcs/spectral_norm.h
@@ -0,0 +1,109 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace phi {
+
+template <typename T,
+          size_t D,
+          int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+using Tensor = framework::Tensor;
+
+using Array1 = Eigen::DSizes<int, 1>;
+using Array2 = Eigen::DSizes<int, 2>;
+using IndexPair = Eigen::IndexPair<int>;
+
+template <typename DeviceContext, typename T>
+static inline void TransCompute(const int rank,
+                                const Tensor& in,
+                                Tensor* out,
+                                const std::vector<int>& perm,
+                                const DeviceContext& dev_ctx) {
+  if (rank <= 1 || rank > 5) {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
+        rank));
+  }
+
+  switch (rank) {
+    case 2:
+      phi::funcs::Transpose<DeviceContext, T, 2> trans2;
+      trans2(dev_ctx, in, out, perm);
+      break;
+    case 3:
+      phi::funcs::Transpose<DeviceContext, T, 3> trans3;
+      trans3(dev_ctx, in, out, perm);
+      break;
+    case 4:
+      phi::funcs::Transpose<DeviceContext, T, 4> trans4;
+      trans4(dev_ctx, in, out, perm);
+      break;
+    case 5:
+      phi::funcs::Transpose<DeviceContext, T, 5> trans5;
+      trans5(dev_ctx, in, out, perm);
+      break;
+    default:
+      break;
+  }
+}
+
+template <typename DeviceContext, typename T>
+static inline void CalcMatrixSigmaAndNormWeight(
+    Tensor* sigma,
+    Tensor* u,
+    Tensor* v,
+    Tensor* weight,
+    const int power_iters,
+    const float eps,
+    const framework::ExecutionContext& ctx) {
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
+  auto weight_t = EigenTensor<T, 2>::From(*weight);
+  auto u_t = EigenTensor<T, 2>::From(*u);
+  auto v_t = EigenTensor<T, 2>::From(*v);
+
+  const int h = weight->dims()[0];
+  const int w = weight->dims()[1];
+
+  for (int i = 0; i < power_iters; i++) {
+    // V = W^T * U / ||W^T * U||_2
+    blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
+    auto v_t_norm =
+        v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(w));
+    v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
+    // U = W * V / ||W * V||_2
+    blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
+    auto u_t_norm =
+        u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(h));
+    u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
+  }
+  Tensor weight_v;
+  weight_v.mutable_data<T>({h, 1}, ctx.GetPlace());
+  blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
+  auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
+  sigma_t.device(place) = (u_t * weight_v_t)
+                              .sum()
+                              .eval()
+                              .reshape(Array2(1, 1))
+                              .broadcast(Array2(h, w));
+  weight_t.device(place) = weight_t / sigma_t;
+}
+
+} // namespace phi
diff --git a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
new file mode 100644
index 0000000000000..564892f1679df
--- /dev/null
+++ b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu new file mode 100644 index 0000000000000..564892f1679df --- /dev/null +++ b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. \ No newline at end of file diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h new file mode 100644 index 0000000000000..38e2a6ac45202 --- /dev/null +++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SpectrumNormGradKernel(const Context& dev_ctx
+                            const DenseTensor& weight,
+                            const DenseTensor& u,
+                            const DenseTensor& v,
+                            const DenseTensor& out_grad,
+                            int dim,
+                            int power_iters,
+                            float eps,
+                            DenseTensor* weight_grad){
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+  auto weight = ctx.Input<Tensor>("Weight");
+  auto u = ctx.Input<Tensor>("U");
+  auto v = ctx.Input<Tensor>("V");
+  auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
+
+  int dim = ctx.Attr<int>("dim");
+  int power_iters = ctx.Attr<int>("power_iters");
+  float eps = ctx.Attr<float>("eps");
+
+  const int h = u->dims()[0];
+  const int w = v->dims()[0];
+
+  Tensor weight_mat, out_grad_mat;
+  auto dims = weight->dims();
+  const int rank = dims.size();
+  std::vector<int> real_dims;
+  if (dim != 0) {
+    std::vector<int> perm;
+    perm.push_back(dim);
+    real_dims.push_back(dims[dim]);
+    for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+    }
+    weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
+    out_grad_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
+    TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+    TransCompute<DeviceContext, T>(
+        rank, *out_grad, &out_grad_mat, perm, dev_ctx);
+  } else {
+    for (int i = 0; i < rank; i++) {
+        real_dims.push_back(i);
+    }
+    paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    paddle::framework::TensorCopySync(
+        *out_grad, ctx.GetPlace(), &out_grad_mat);
+  }
+  weight_mat = weight_mat.Resize({h, w});
+  out_grad_mat = out_grad_mat.Resize({h, w});
+
+  Tensor sigma;
+  sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
+  Tensor uu, vv;
+  paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
+  paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
+  CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
+                                                 &(uu.Resize({h, 1})),
+                                                 &(vv.Resize({w, 1})),
+                                                 &weight_mat,
+                                                 power_iters,
+                                                 eps,
+                                                 ctx);
+
+  Tensor uv;
+  uv.mutable_data<T>({h, w}, ctx.GetPlace());
+  blas.MatMul(
+      uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
+
+  Tensor weight_grad_mat;
+  weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
+  auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
+  auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
+  auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
+  auto sigma_t = EigenTensor<T, 2>::From(sigma);
+  auto uv_t = EigenTensor<T, 2>::From(uv);
+  weight_mat_t.device(place) =
+      weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
+  weight_grad_mat_t.device(place) =
+      out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
+      sigma_t;
+
+  if (dim != 0) {
+    std::vector<int> perm;
+    for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+    }
+    weight_grad->mutable_data<T>(dims, ctx.GetPlace());
+    TransCompute<DeviceContext, T>(
+        rank,
+        weight_grad_mat.Resize(phi::make_ddim(real_dims)),
+        weight_grad,
+        perm,
+        dev_ctx);
+  } else {
+    paddle::framework::TensorCopySync(
+        weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
+  }
+}
+
+} // namespace phi
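Both kernel bodies in this patch share the same reshaping step before the matrix math: when `dim != 0`, the normalized axis is permuted to the front and the weight is then viewed as an h x w matrix with h = u->dims()[0] and w = v->dims()[0]. As a concrete instance (values invented for illustration): a weight of shape {3, 4, 5} with dim = 1 gives perm = {1, 0, 2}, real_dims = {4, 3, 5}, h = 4, and w = 15 = 3 * 5. A minimal standalone sketch of the same perm construction:

    #include <cassert>
    #include <vector>

    // Mirrors the perm-building loop in the kernels; illustrative only.
    std::vector<int> BuildPerm(int rank, int dim) {
      std::vector<int> perm;
      perm.push_back(dim);  // the normalized axis goes first
      for (int i = 0; i < rank; i++) {
        if (i != dim) perm.push_back(i);
      }
      return perm;
    }

    int main() {
      assert((BuildPerm(3, 1) == std::vector<int>{1, 0, 2}));
      return 0;
    }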
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
new file mode 100644
index 0000000000000..f4c414883bea1
--- /dev/null
+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SpectrumNormKernel(const Context& dev_ctx
+                        const DenseTensor& weight,
+                        const DenseTensor& u,
+                        const DenseTensor& v,
+                        int dim,
+                        int power_iters,
+                        float eps,
+                        DenseTensor* out)
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  auto weight = ctx.Input<Tensor>("Weight");
+  auto u = ctx.Input<Tensor>("U");
+  auto v = ctx.Input<Tensor>("V");
+  auto out = ctx.Output<Tensor>("Out");
+
+  int dim = ctx.Attr<int>("dim");
+  int power_iters = ctx.Attr<int>("power_iters");
+  float eps = ctx.Attr<float>("eps");
+
+  const int h = u->dims()[0];
+  const int w = v->dims()[0];
+
+  Tensor weight_mat;
+  auto dims = weight->dims();
+  const int rank = dims.size();
+  std::vector<int> real_dims;
+  if (dim != 0) {
+    std::vector<int> perm;
+    perm.push_back(dim);
+    real_dims.push_back(dims[dim]);
+    for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+    }
+  }
+  weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
+  TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+  } else {
+    for (int i = 0; i < rank; i++) {
+        real_dims.push_back(i);
+    }
+    paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+  }
+  weight_mat = weight_mat.Resize({h, w});
+
+  Tensor sigma;
+  sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
+  Tensor uu, vv;
+  paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
+  paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
+  CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
+                                                 &(uu.Resize({h, 1})),
+                                                 &(vv.Resize({w, 1})),
+                                                 &weight_mat,
+                                                 power_iters,
+                                                 eps,
+                                                 ctx);
+
+  if (dim != 0) {
+    std::vector<int> perm;
+    for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+    }
+    out->mutable_data<T>(dims, ctx.GetPlace());
+    TransCompute<DeviceContext, T>(
+        rank,
+        weight_mat.Resize(phi::make_ddim(real_dims)),
+        out,
+        perm,
+        dev_ctx);
+  } else {
+    paddle::framework::TensorCopySync(
+        weight_mat.Resize(dims), ctx.GetPlace(), out);
+  }
+}
+
+} // namespace phi
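The next patch rewrites these copied bodies from the fluid idiom, where inputs and attributes are fetched from an ExecutionContext, to the phi idiom, where they arrive directly as kernel parameters. The recurring mechanical change is the allocation-and-copy pattern; roughly (a side-by-side sketch of the two styles, not lines quoted from the patch):

    // fluid style (before): fetch from the context, allocate via mutable_data
    auto* weight = ctx.Input<Tensor>("Weight");
    weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
    paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);

    // phi style (after): parameters come in directly, allocate via dev_ctx
    weight_mat.Resize(phi::make_ddim(real_dims));
    dev_ctx.template Alloc<T>(&weight_mat);
    phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), false, &weight_mat);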
From 0fb645cced12f13f3fcffacab6cf3dc4dec28603 Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Sat, 23 Jul 2022 02:58:42 +0800
Subject: [PATCH 03/13] Transfer implementation code

---
 paddle/phi/kernels/funcs/spectral_norm.h      |  52 ++++----
 .../impl/spectral_norm_grad_kernel_impl.h     | 113 ++++++++----------
 .../kernels/impl/spectral_norm_kernel_impl.h  |  84 ++++++-------
 3 files changed, 115 insertions(+), 134 deletions(-)

diff --git a/paddle/phi/kernels/funcs/spectral_norm.h b/paddle/phi/kernels/funcs/spectral_norm.h
index 75c460b5382aa..b550e0a96b290 100644
--- a/paddle/phi/kernels/funcs/spectral_norm.h
+++ b/paddle/phi/kernels/funcs/spectral_norm.h
@@ -14,46 +14,41 @@
 
 #pragma once
 
-namespace phi {
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 
-template <typename T,
-          size_t D,
-          int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
-using Tensor = framework::Tensor;
+namespace phi {
 
 using Array1 = Eigen::DSizes<int, 1>;
 using Array2 = Eigen::DSizes<int, 2>;
 using IndexPair = Eigen::IndexPair<int>;
 
-template <typename DeviceContext, typename T>
-static inline void TransCompute(const int rank,
-                                const Tensor& in,
-                                Tensor* out,
+template <typename Context, typename T>
+static inline void TransCompute2DTo5D(const Context& dev_ctx,
+                                      const DenseTensor& in,
+                                      const int rank,
                                 const std::vector<int>& perm,
-                                const DeviceContext& dev_ctx) {
+                                      DenseTensor* out) {
   if (rank <= 1 || rank > 5) {
-    PADDLE_THROW(paddle::platform::errors::Fatal(
+    PADDLE_THROW(errors::Fatal(
         "Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
         rank));
   }
 
   switch (rank) {
     case 2:
-      phi::funcs::Transpose<DeviceContext, T, 2> trans2;
+      phi::funcs::Transpose<Context, T, 2> trans2;
       trans2(dev_ctx, in, out, perm);
       break;
     case 3:
-      phi::funcs::Transpose<DeviceContext, T, 3> trans3;
+      phi::funcs::Transpose<Context, T, 3> trans3;
      trans3(dev_ctx, in, out, perm);
       break;
    case 4:
-      phi::funcs::Transpose<DeviceContext, T, 4> trans4;
+      phi::funcs::Transpose<Context, T, 4> trans4;
       trans4(dev_ctx, in, out, perm);
       break;
     case 5:
-      phi::funcs::Transpose<DeviceContext, T, 5> trans5;
+      phi::funcs::Transpose<Context, T, 5> trans5;
       trans5(dev_ctx, in, out, perm);
       break;
     default:
@@ -61,17 +56,17 @@ static inline void TransCompute(const int rank,
   }
 }
 
-template <typename DeviceContext, typename T>
+template <typename Context, typename T>
 static inline void CalcMatrixSigmaAndNormWeight(
-    Tensor* sigma,
-    Tensor* u,
-    Tensor* v,
-    Tensor* weight,
+    const Context& dev_ctx,
+    DenseTensor* weight,
+    DenseTensor* u,
+    DenseTensor* v,
+    DenseTensor* sigma,
     const int power_iters,
-    const float eps,
-    const framework::ExecutionContext& ctx) {
-  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-  auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
+    const float eps) {
+  auto& place = *dev_ctx.eigen_device();
+  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
   auto sigma_t = EigenTensor<T, 2>::From(*sigma);
   auto weight_t = EigenTensor<T, 2>::From(*weight);
   auto u_t = EigenTensor<T, 2>::From(*u);
@@ -94,8 +89,9 @@ static inline void CalcMatrixSigmaAndNormWeight(
         Array1(h));
     u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
   }
-  Tensor weight_v;
-  weight_v.mutable_data<T>({h, 1}, ctx.GetPlace());
+  DenseTensor weight_v;
+  weight_v.Resize({h, 1});
+  dev_ctx.template Alloc<T>(&weight_v);
   blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
   auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
   sigma_t.device(place) = (u_t * weight_v_t)
diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
index 38e2a6ac45202..edb47debe0811 100644
--- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/spectral_norm.h"
 
 namespace phi {
 
@@ -28,24 +28,14 @@ void SpectrumNormGradKernel(const Context& dev_ctx
                             int power_iters,
                             float eps,
                             DenseTensor* weight_grad){
-  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
-  auto weight = ctx.Input<Tensor>("Weight");
-  auto u = ctx.Input<Tensor>("U");
-  auto v = ctx.Input<Tensor>("V");
-  auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-  auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
-
-  int dim = ctx.Attr<int>("dim");
-  int power_iters = ctx.Attr<int>("power_iters");
-  float eps = ctx.Attr<float>("eps");
+  auto& place = *dev_ctx.eigen_device();
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
 
-  const int h = u->dims()[0];
-  const int w = v->dims()[0];
+  const int h = u.dims()[0];
+  const int w = v.dims()[0];
 
-  Tensor weight_mat, out_grad_mat;
-  auto dims = weight->dims();
+  DenseTensor weight_mat, out_grad_mat;
+  auto dims = weight.dims();
   const int rank = dims.size();
   std::vector<int> real_dims;
   if (dim != 0) {
@@ -53,47 +43,50 @@ void SpectrumNormGradKernel(const Context& dev_ctx
     perm.push_back(dim);
     real_dims.push_back(dims[dim]);
     for (int i = 0; i < rank; i++) {
-        if (i != dim) {
-          perm.push_back(i);
-          real_dims.push_back(dims[i]);
-        }
+      if (i != dim) {
+        perm.push_back(i);
+        real_dims.push_back(dims[i]);
+      }
     }
-    weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
-    out_grad_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
-    TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
-    TransCompute<DeviceContext, T>(
-        rank, *out_grad, &out_grad_mat, perm, dev_ctx);
+    weight_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&weight_mat);
+    out_grad_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&out_grad_mat);
+    TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
+    TransCompute2DTo5D<Context, T>(dev_ctx, out_grad, rank, perm, &out_grad_mat);
   } else {
     for (int i = 0; i < rank; i++) {
-        real_dims.push_back(i);
+      real_dims.push_back(i);
     }
-    paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
-    paddle::framework::TensorCopySync(
-        *out_grad, ctx.GetPlace(), &out_grad_mat);
+    phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), false, &weight_mat);
+    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, &out_grad_mat);
   }
   weight_mat = weight_mat.Resize({h, w});
   out_grad_mat = out_grad_mat.Resize({h, w});
 
-  Tensor sigma;
-  sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
-  Tensor uu, vv;
-  paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
-  paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
-  CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
-                                                 &(uu.Resize({h, 1})),
-                                                 &(vv.Resize({w, 1})),
-                                                 &weight_mat,
-                                                 power_iters,
-                                                 eps,
-                                                 ctx);
+  DenseTensor sigma;
+  sigma.Resize(weight_mat.dims());
+  dev_ctx.template Alloc<T>(&sigma);
+  DenseTensor uu, vv;
+  phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), false, &uu);
+  phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), false, &vv);
+  CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
+                                           &weight_mat,
+                                           &(uu.Resize({h, 1})),
+                                           &(vv.Resize({w, 1})),
+                                           &sigma,
+                                           power_iters,
+                                           eps);
 
-  Tensor uv;
-  uv.mutable_data<T>({h, w}, ctx.GetPlace());
+  DenseTensor uv;
+  uv.Resize({h, w});
+  dev_ctx.template Alloc<T>(&uv);
   blas.MatMul(
       uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
 
-  Tensor weight_grad_mat;
-  weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
+  DenseTensor weight_grad_mat;
+  weight_grad_mat.Resize({h, w});
+  dev_ctx.template Alloc<T>(&weight_grad_mat);
   auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
   auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
   auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
@@ -108,24 +101,24 @@ void SpectrumNormGradKernel(const Context& dev_ctx
   if (dim != 0) {
     std::vector<int> perm;
     for (int i = 0; i < rank; i++) {
-        if (i < dim) {
-          perm.push_back(i + 1);
-        } else if (i == dim) {
-          perm.push_back(0);
-        } else {
-          perm.push_back(i);
-        }
+      if (i < dim) {
+        perm.push_back(i + 1);
+      } else if (i == dim) {
+        perm.push_back(0);
+      } else {
+        perm.push_back(i);
+      }
     }
-    weight_grad->mutable_data<T>(dims, ctx.GetPlace());
-    TransCompute<DeviceContext, T>(
-        rank,
+    weight_grad->Resize(dims);
+    dev_ctx.template Alloc<T>(weight_grad);
+    TransCompute2DTo5D<Context, T>(
+        dev_ctx,
         weight_grad_mat.Resize(phi::make_ddim(real_dims)),
-        weight_grad,
+        rank,
         perm,
-        dev_ctx);
+        weight_grad);
   } else {
-
-    paddle::framework::TensorCopySync(
-        weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
+    phi::Copy(dev_ctx, weight_grad_mat.Resize(dims), dev_ctx.GetPlace(), false, weight_grad);
   }
 }
 
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
index f4c414883bea1..1f7d4e873537b 100644
--- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/spectral_norm.h"
 
 namespace phi {
 
@@ -27,21 +27,11 @@ void SpectrumNormKernel(const Context& dev_ctx
                         int power_iters,
                         float eps,
                         DenseTensor* out)
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  auto weight = ctx.Input<Tensor>("Weight");
-  auto u = ctx.Input<Tensor>("U");
-  auto v = ctx.Input<Tensor>("V");
-  auto out = ctx.Output<Tensor>("Out");
+  const int h = u.dims()[0];
+  const int w = v.dims()[0];
 
-  int dim = ctx.Attr<int>("dim");
-  int power_iters = ctx.Attr<int>("power_iters");
-  float eps = ctx.Attr<float>("eps");
-
-  const int h = u->dims()[0];
-  const int w = v->dims()[0];
-
-  Tensor weight_mat;
-  auto dims = weight->dims();
+  DenseTensor weight_mat;
+  auto dims = weight.dims();
   const int rank = dims.size();
   std::vector<int> real_dims;
   if (dim != 0) {
@@ -49,55 +39,57 @@ void SpectrumNormKernel(const Context& dev_ctx
     perm.push_back(dim);
     real_dims.push_back(dims[dim]);
     for (int i = 0; i < rank; i++) {
-        if (i != dim) {
-          perm.push_back(i);
-          real_dims.push_back(dims[i]);
+      if (i != dim) {
+        perm.push_back(i);
+        real_dims.push_back(dims[i]);
+      }
     }
-  }
-  weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
-  TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+    weight_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&weight_mat);
+    TransCompute2DTo5D<Context, T>(rank, weight, &weight_mat, perm, dev_ctx);
   } else {
     for (int i = 0; i < rank; i++) {
-        real_dims.push_back(i);
+      real_dims.push_back(i);
     }
-    paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), false, &weight_mat);
   }
   weight_mat = weight_mat.Resize({h, w});
 
-  Tensor sigma;
-  sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
-  Tensor uu, vv;
-  paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
-  paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
-  CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
-                                                 &(uu.Resize({h, 1})),
-                                                 &(vv.Resize({w, 1})),
-                                                 &weight_mat,
-                                                 power_iters,
-                                                 eps,
-                                                 ctx);
+  DenseTensor sigma;
+  sigma.Resize(weight_mat.dims());
+  dev_ctx.template Alloc<T>(sigma);
+  DenseTensor uu, vv;
+  phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), false, &uu);
+  phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), false, &vv);
+  CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
+                                           &weight_mat,
+                                           &(uu.Resize({h, 1})),
+                                           &(vv.Resize({w, 1})),
+                                           &sigma,
+                                           power_iters,
+                                           eps);
 
   if (dim != 0) {
     std::vector<int> perm;
     for (int i = 0; i < rank; i++) {
-        if (i < dim) {
-          perm.push_back(i + 1);
-        } else if (i == dim) {
-          perm.push_back(0);
-        } else {
-          perm.push_back(i);
-        }
+      if (i < dim) {
+        perm.push_back(i + 1);
+      } else if (i == dim) {
+        perm.push_back(0);
+      } else {
+        perm.push_back(i);
+      }
     }
-    out->mutable_data<T>(dims, ctx.GetPlace());
-    TransCompute<DeviceContext, T>(
+    out->Resize(dims)
+    dev_ctx.template Alloc<T>(out);
+    TransCompute2DTo5D<Context, T>(
         rank,
         weight_mat.Resize(phi::make_ddim(real_dims)),
         out,
         perm,
         dev_ctx);
   } else {
-    paddle::framework::TensorCopySync(
-        weight_mat.Resize(dims), ctx.GetPlace(), out);
+    phi::Copy(dev_ctx, weight_mat.Resize(dims),
dev_ctx.GetPlace(), false, out); } } From 6b2acaacc66467c95a23455ec6300adf3553f58b Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sat, 23 Jul 2022 03:24:29 +0800 Subject: [PATCH 04/13] Fix: Move out_grad to first --- paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h | 2 +- paddle/phi/kernels/spectral_norm_grad_kernel.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h index edb47debe0811..02d6d1fd0647d 100644 --- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h @@ -20,10 +20,10 @@ namespace phi { template void SpectrumNormGradKernel(const Context& dev_ctx + const DenseTensor& out_grad, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, - const DenseTensor& out_grad, int dim, int power_iters, float eps, diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h index 047b22d02ac6c..8e89ac50918d0 100644 --- a/paddle/phi/kernels/spectral_norm_grad_kernel.h +++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h @@ -17,10 +17,10 @@ namespace phi { template void SpectrumNormGradKernel(const Context& dev_ctx + const DenseTensor& out_grad, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, - const DenseTensor& out_grad, int dim, int power_iters, float eps, From 8a7ae3369c6477123f3fb9a7cf4b5f0e69858a6a Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sat, 23 Jul 2022 03:25:48 +0800 Subject: [PATCH 05/13] Register new kernels --- .../kernels/cpu/spectral_norm_grad_kernel.cc | 15 ++++++++- .../phi/kernels/cpu/spectral_norm_kernel.cc | 15 ++++++++- .../kernels/gpu/spectral_norm_grad_kernel.cu | 15 ++++++++- .../phi/kernels/gpu/spectral_norm_kernel.cu | 15 ++++++++- paddle/phi/ops/compat/spectral_norm_sig.cc | 32 +++++++++++++++++++ 5 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 paddle/phi/ops/compat/spectral_norm_sig.cc diff --git a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc index 564892f1679df..905f10d780b13 100644 --- a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc @@ -10,4 +10,17 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/backends/cpu/cpu_context.h" + +#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" +#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(spectral_norm_grad, + CPU, + ALL_LAYOUT, + phi::SpectralNormGradKernel, + float, + double) {} \ No newline at end of file diff --git a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc index 564892f1679df..b9d496a55e797 100644 --- a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc @@ -10,4 +10,17 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/backends/cpu/cpu_context.h" + +#include "paddle/phi/kernels/spectral_norm_kernel.h" +#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" + +PD_REGISTER_KERNEL(spectral_norm, + CPU, + ALL_LAYOUT, + phi::SpectralNormKernel, + float, + double) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu index 564892f1679df..d4b747eb39f45 100644 --- a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu @@ -10,4 +10,17 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + +#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" +#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(spectral_norm_grad, + GPU, + ALL_LAYOUT, + phi::SpectralNormGradKernel, + float, + double) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu index 564892f1679df..4a6223ff3570d 100644 --- a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu @@ -10,4 +10,17 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + +#include "paddle/phi/kernels/spectral_norm_kernel.h" +#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" + +PD_REGISTER_KERNEL(spectral_norm, + GPU, + ALL_LAYOUT, + phi::SpectralNormKernel, + float, + double) {} \ No newline at end of file diff --git a/paddle/phi/ops/compat/spectral_norm_sig.cc b/paddle/phi/ops/compat/spectral_norm_sig.cc new file mode 100644 index 0000000000000..a061d47fc9008 --- /dev/null +++ b/paddle/phi/ops/compat/spectral_norm_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi{ + +KernelSignature SpectralNormOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("spectral_norm", {"Weight", "U", "V"}, {"dim", "power_iters", "eps"}, {"Out"}); +} + +KernelSignature SpectralNormGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "spectral_norm_grad", {"Out@GRAD", "Weight", "U", "V"}, {"dim", "power_iters", "eps"}, {"Weight@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(spectral_norm, phi::SpectralNormOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad, phi::SpectralNormGradOpArgumentMapping); From d061908916c175cbc25fef9fc2c18e88b84338bf Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sat, 23 Jul 2022 03:27:10 +0800 Subject: [PATCH 06/13] Remove old kernels --- paddle/fluid/operators/spectral_norm_op.cc | 6 - paddle/fluid/operators/spectral_norm_op.cu | 22 -- paddle/fluid/operators/spectral_norm_op.h | 299 --------------------- 3 files changed, 327 deletions(-) delete mode 100644 paddle/fluid/operators/spectral_norm_op.cu delete mode 100644 paddle/fluid/operators/spectral_norm_op.h diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index a6addb2e6f46d..ff4c44381dc49 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -256,9 +256,3 @@ REGISTER_OPERATOR(spectral_norm, ops::SpectralNormGradOpMaker, ops::SpectralNormGradOpMaker); REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad); -REGISTER_OP_CPU_KERNEL(spectral_norm, - ops::SpectralNormKernel, - ops::SpectralNormKernel); -REGISTER_OP_CPU_KERNEL(spectral_norm_grad, - ops::SpectralNormGradKernel, - ops::SpectralNormGradKernel); diff --git a/paddle/fluid/operators/spectral_norm_op.cu b/paddle/fluid/operators/spectral_norm_op.cu deleted file mode 100644 index ea90e3b4c122b..0000000000000 --- a/paddle/fluid/operators/spectral_norm_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/spectral_norm_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - spectral_norm, - ops::SpectralNormKernel, - ops::SpectralNormKernel); -REGISTER_OP_CUDA_KERNEL( - spectral_norm_grad, - ops::SpectralNormGradKernel, - ops::SpectralNormGradKernel); diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h deleted file mode 100644 index ffe8a40c35a46..0000000000000 --- a/paddle/fluid/operators/spectral_norm_op.h +++ /dev/null @@ -1,299 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -using EigenTensor = framework::EigenTensor; -using Tensor = framework::Tensor; - -using Array1 = Eigen::DSizes; -using Array2 = Eigen::DSizes; -using IndexPair = Eigen::IndexPair; - -template -static inline void TransCompute(const int rank, - const Tensor& in, - Tensor* out, - const std::vector& perm, - const DeviceContext& dev_ctx) { - if (rank <= 1 || rank > 5) { - PADDLE_THROW(paddle::platform::errors::Fatal( - "Weight rank of SpectralNorm should be in range [2, 5], but got %d.", - rank)); - } - - switch (rank) { - case 2: - phi::funcs::Transpose trans2; - trans2(dev_ctx, in, out, perm); - break; - case 3: - phi::funcs::Transpose trans3; - trans3(dev_ctx, in, out, perm); - break; - case 4: - phi::funcs::Transpose trans4; - trans4(dev_ctx, in, out, perm); - break; - case 5: - phi::funcs::Transpose trans5; - trans5(dev_ctx, in, out, perm); - break; - default: - break; - } -} - -template -static inline void CalcMatrixSigmaAndNormWeight( - Tensor* sigma, - Tensor* u, - Tensor* v, - Tensor* weight, - const int power_iters, - const float eps, - const framework::ExecutionContext& ctx) { - auto& place = *ctx.template device_context().eigen_device(); - auto blas = phi::funcs::GetBlas(ctx); - auto sigma_t = EigenTensor::From(*sigma); - auto weight_t = EigenTensor::From(*weight); - auto u_t = EigenTensor::From(*u); - auto v_t = EigenTensor::From(*v); - - const int h = weight->dims()[0]; - const int w = weight->dims()[1]; - - for (int i = 0; i < power_iters; i++) { - // V = W^T * U / ||W^T * U||_2 - blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); - auto v_t_norm = - v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( - Array1(w)); - v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); - // U = W^T * V / ||W^T * V||_2 - blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); - auto u_t_norm = - u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( - Array1(h)); - u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); - } - Tensor weight_v; - weight_v.mutable_data({h, 1}, ctx.GetPlace()); - blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); - auto weight_v_t = EigenTensor::From(weight_v); - sigma_t.device(place) = (u_t * weight_v_t) - .sum() - .eval() - .reshape(Array2(1, 1)) - .broadcast(Array2(h, w)); - weight_t.device(place) = weight_t / sigma_t; -} - -template -class SpectralNormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - auto weight = ctx.Input("Weight"); - auto u = ctx.Input("U"); - auto v = ctx.Input("V"); - auto out = ctx.Output("Out"); - - int dim = ctx.Attr("dim"); - int power_iters = ctx.Attr("power_iters"); - float eps = ctx.Attr("eps"); - - const int h = u->dims()[0]; - const int w = v->dims()[0]; - - Tensor weight_mat; - auto dims = 
weight->dims(); - const int rank = dims.size(); - std::vector real_dims; - if (dim != 0) { - std::vector perm; - perm.push_back(dim); - real_dims.push_back(dims[dim]); - for (int i = 0; i < rank; i++) { - if (i != dim) { - perm.push_back(i); - real_dims.push_back(dims[i]); - } - } - weight_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - TransCompute(rank, *weight, &weight_mat, perm, dev_ctx); - } else { - for (int i = 0; i < rank; i++) { - real_dims.push_back(i); - } - paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); - } - weight_mat = weight_mat.Resize({h, w}); - - Tensor sigma; - sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu); - paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv); - CalcMatrixSigmaAndNormWeight(&sigma, - &(uu.Resize({h, 1})), - &(vv.Resize({w, 1})), - &weight_mat, - power_iters, - eps, - ctx); - - if (dim != 0) { - std::vector perm; - for (int i = 0; i < rank; i++) { - if (i < dim) { - perm.push_back(i + 1); - } else if (i == dim) { - perm.push_back(0); - } else { - perm.push_back(i); - } - } - out->mutable_data(dims, ctx.GetPlace()); - TransCompute( - rank, - weight_mat.Resize(phi::make_ddim(real_dims)), - out, - perm, - dev_ctx); - } else { - paddle::framework::TensorCopySync( - weight_mat.Resize(dims), ctx.GetPlace(), out); - } - } -}; - -template -class SpectralNormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& place = *ctx.template device_context().eigen_device(); - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(ctx); - auto weight = ctx.Input("Weight"); - auto u = ctx.Input("U"); - auto v = ctx.Input("V"); - auto out_grad = ctx.Input(framework::GradVarName("Out")); - auto weight_grad = ctx.Output(framework::GradVarName("Weight")); - - int dim = ctx.Attr("dim"); - int power_iters = ctx.Attr("power_iters"); - float eps = ctx.Attr("eps"); - - const int h = u->dims()[0]; - const int w = v->dims()[0]; - - Tensor weight_mat, out_grad_mat; - auto dims = weight->dims(); - const int rank = dims.size(); - std::vector real_dims; - if (dim != 0) { - std::vector perm; - perm.push_back(dim); - real_dims.push_back(dims[dim]); - for (int i = 0; i < rank; i++) { - if (i != dim) { - perm.push_back(i); - real_dims.push_back(dims[i]); - } - } - weight_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - out_grad_mat.mutable_data(phi::make_ddim(real_dims), ctx.GetPlace()); - TransCompute(rank, *weight, &weight_mat, perm, dev_ctx); - TransCompute( - rank, *out_grad, &out_grad_mat, perm, dev_ctx); - } else { - for (int i = 0; i < rank; i++) { - real_dims.push_back(i); - } - paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); - paddle::framework::TensorCopySync( - *out_grad, ctx.GetPlace(), &out_grad_mat); - } - weight_mat = weight_mat.Resize({h, w}); - out_grad_mat = out_grad_mat.Resize({h, w}); - - Tensor sigma; - sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); - Tensor uu, vv; - paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu); - paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv); - CalcMatrixSigmaAndNormWeight(&sigma, - &(uu.Resize({h, 1})), - &(vv.Resize({w, 1})), - &weight_mat, - power_iters, - eps, - ctx); - - Tensor uv; - uv.mutable_data({h, w}, ctx.GetPlace()); - blas.MatMul( - uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0)); - - Tensor weight_grad_mat; - 
weight_grad_mat.mutable_data({h, w}, ctx.GetPlace()); - auto weight_grad_mat_t = EigenTensor::From(weight_grad_mat); - auto weight_mat_t = EigenTensor::From(weight_mat); - auto out_grad_mat_t = EigenTensor::From(out_grad_mat); - auto sigma_t = EigenTensor::From(sigma); - auto uv_t = EigenTensor::From(uv); - weight_mat_t.device(place) = - weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w)); - weight_grad_mat_t.device(place) = - out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) / - sigma_t; - - if (dim != 0) { - std::vector perm; - for (int i = 0; i < rank; i++) { - if (i < dim) { - perm.push_back(i + 1); - } else if (i == dim) { - perm.push_back(0); - } else { - perm.push_back(i); - } - } - weight_grad->mutable_data(dims, ctx.GetPlace()); - TransCompute( - rank, - weight_grad_mat.Resize(phi::make_ddim(real_dims)), - weight_grad, - perm, - dev_ctx); - } else { - paddle::framework::TensorCopySync( - weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad); - } - } -}; - -} // namespace operators -} // namespace paddle From 59ef331d840cb9bfd9b8de0f14e94b85146a1b3c Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Sat, 23 Jul 2022 13:31:19 +0800 Subject: [PATCH 07/13] Move out_grad to last --- paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h | 2 +- paddle/phi/kernels/spectral_norm_grad_kernel.h | 2 +- paddle/phi/ops/compat/spectral_norm_sig.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h index 02d6d1fd0647d..edb47debe0811 100644 --- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h @@ -20,10 +20,10 @@ namespace phi { template void SpectrumNormGradKernel(const Context& dev_ctx - const DenseTensor& out_grad, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, + const DenseTensor& out_grad, int dim, int power_iters, float eps, diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h index 8e89ac50918d0..047b22d02ac6c 100644 --- a/paddle/phi/kernels/spectral_norm_grad_kernel.h +++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h @@ -17,10 +17,10 @@ namespace phi { template void SpectrumNormGradKernel(const Context& dev_ctx - const DenseTensor& out_grad, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, + const DenseTensor& out_grad, int dim, int power_iters, float eps, diff --git a/paddle/phi/ops/compat/spectral_norm_sig.cc b/paddle/phi/ops/compat/spectral_norm_sig.cc index a061d47fc9008..16705ba76390d 100644 --- a/paddle/phi/ops/compat/spectral_norm_sig.cc +++ b/paddle/phi/ops/compat/spectral_norm_sig.cc @@ -23,7 +23,7 @@ KernelSignature SpectralNormOpArgumentMapping(const ArgumentMappingContext& ctx) KernelSignature SpectralNormGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "spectral_norm_grad", {"Out@GRAD", "Weight", "U", "V"}, {"dim", "power_iters", "eps"}, {"Weight@GRAD"}); + "spectral_norm_grad", {"Weight", "U", "V", "Out@GRAD"}, {"dim", "power_iters", "eps"}, {"Weight@GRAD"}); } } // namespace phi From 032313bcfd96b25929edfff0c378d8fcc3447dc0 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Mon, 25 Jul 2022 13:39:34 +0800 Subject: [PATCH 08/13] Fix bugs --- paddle/fluid/operators/spectral_norm_op.cc | 2 -- paddle/phi/kernels/funcs/spectral_norm.h | 2 ++ .../impl/spectral_norm_grad_kernel_impl.h | 12 
+++++----- .../kernels/impl/spectral_norm_kernel_impl.h | 24 +++++++++---------- .../phi/kernels/spectral_norm_grad_kernel.h | 2 +- paddle/phi/kernels/spectral_norm_kernel.h | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index ff4c44381dc49..17076e207de49 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -9,8 +9,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/spectral_norm_op.h" - #include #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/phi/kernels/funcs/spectral_norm.h b/paddle/phi/kernels/funcs/spectral_norm.h index b550e0a96b290..41c27e87ab0e4 100644 --- a/paddle/phi/kernels/funcs/spectral_norm.h +++ b/paddle/phi/kernels/funcs/spectral_norm.h @@ -15,6 +15,8 @@ #pragma once #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h index edb47debe0811..808611be72572 100644 --- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h @@ -19,7 +19,7 @@ namespace phi { template -void SpectrumNormGradKernel(const Context& dev_ctx +void SpectralNormGradKernel(const Context& dev_ctx, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, @@ -58,8 +58,8 @@ void SpectrumNormGradKernel(const Context& dev_ctx for (int i = 0; i < rank; i++) { real_dims.push_back(i); } - phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), false, &weight_mat); - phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, &out_grad_mat); + phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), true, &out_grad_mat); } weight_mat = weight_mat.Resize({h, w}); out_grad_mat = out_grad_mat.Resize({h, w}); @@ -68,8 +68,8 @@ void SpectrumNormGradKernel(const Context& dev_ctx sigma.Resize(weight_mat.dims()); dev_ctx.template Alloc(&sigma); DenseTensor uu, vv; - phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), false, &uu); - phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), false, &vv); + phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu); + phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv); CalcMatrixSigmaAndNormWeight(dev_ctx, &weight_mat, &(uu.Resize({h, 1})), @@ -118,7 +118,7 @@ void SpectrumNormGradKernel(const Context& dev_ctx perm, weight_grad); } else { - phi::Copy(dev_ctx, weight_grad_mat.Resize(dims), dev_ctx.GetPlace(), false, weight_grad); + phi::Copy(dev_ctx, weight_grad_mat.Resize(dims), dev_ctx.GetPlace(), true, weight_grad); } } diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h index 1f7d4e873537b..99ff414d0f35e 100644 --- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h +++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h @@ -19,14 +19,14 @@ namespace phi { template -void SpectrumNormKernel(const Context& dev_ctx +void SpectralNormKernel(const Context& dev_ctx, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, int dim, int power_iters, float eps, - DenseTensor* out) + DenseTensor* out){ const int h = u.dims()[0]; const int w = v.dims()[0]; @@ -46,21 +46,21 @@ void 
SpectrumNormKernel(const Context& dev_ctx } weight_mat.Resize(phi::make_ddim(real_dims)); dev_ctx.template Alloc(&weight_mat); - TransCompute2DTo5D(rank, weight, &weight_mat, perm, dev_ctx); + TransCompute2DTo5D(dev_ctx, weight, rank, perm, &weight_mat); } else { for (int i = 0; i < rank; i++) { real_dims.push_back(i); } - phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), false, &weight_mat); + phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat); } weight_mat = weight_mat.Resize({h, w}); DenseTensor sigma; sigma.Resize(weight_mat.dims()); - dev_ctx.template Alloc(sigma); + dev_ctx.template Alloc(&sigma); DenseTensor uu, vv; - phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), false, &uu); - phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), false, &vv); + phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu); + phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv); CalcMatrixSigmaAndNormWeight(dev_ctx, &weight_mat, &(uu.Resize({h, 1})), @@ -80,16 +80,16 @@ void SpectrumNormKernel(const Context& dev_ctx perm.push_back(i); } } - out->Resize(dims) + out->Resize(dims); dev_ctx.template Alloc(out); TransCompute2DTo5D( - rank, + dev_ctx, weight_mat.Resize(phi::make_ddim(real_dims)), - out, + rank, perm, - dev_ctx); + out); } else { - phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), false, out); + phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), true, out); } } diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h index 047b22d02ac6c..783633fd5cb02 100644 --- a/paddle/phi/kernels/spectral_norm_grad_kernel.h +++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h @@ -16,7 +16,7 @@ limitations under the License. */ namespace phi { template -void SpectrumNormGradKernel(const Context& dev_ctx +void SpectrumNormGradKernel(const Context& dev_ctx, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, diff --git a/paddle/phi/kernels/spectral_norm_kernel.h b/paddle/phi/kernels/spectral_norm_kernel.h index 5c638027fbf2d..89083b994e6e8 100644 --- a/paddle/phi/kernels/spectral_norm_kernel.h +++ b/paddle/phi/kernels/spectral_norm_kernel.h @@ -16,7 +16,7 @@ limitations under the License. 
*/ namespace phi { template -void SpectrumNormKernel(const Context& dev_ctx +void SpectrumNormKernel(const Context& dev_ctx, const DenseTensor& weight, const DenseTensor& u, const DenseTensor& v, From 1fc902f35e69faa44450e4a12f518ac238a632b6 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Mon, 25 Jul 2022 14:26:10 +0800 Subject: [PATCH 09/13] Transfer infermeta --- paddle/fluid/operators/spectral_norm_op.cc | 113 +++------------------ paddle/phi/infermeta/backward.cc | 15 +++ paddle/phi/infermeta/backward.h | 9 ++ paddle/phi/infermeta/ternary.cc | 76 ++++++++++++++ paddle/phi/infermeta/ternary.h | 9 ++ 5 files changed, 124 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 17076e207de49..60f01cc56a13d 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -11,8 +11,12 @@ #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/ternary.h" + namespace paddle { namespace operators { @@ -22,82 +26,6 @@ class SpectralNormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); - - auto dim_weight = ctx->GetInputDim("Weight"); - auto rank_weight = dim_weight.size(); - PADDLE_ENFORCE_GE(rank_weight, - 2, - platform::errors::InvalidArgument( - "The rank of Input(Weights) should be greater equal " - "than 2, but received Weight rank(%d)", - rank_weight)); - PADDLE_ENFORCE_LE(rank_weight, - 5, - platform::errors::InvalidArgument( - "The rank of Input(Weights) should be less equal " - "than 5, but received Weight rank(%d)", - rank_weight)); - - int dim = ctx->Attrs().Get("dim"); - int power_iters = ctx->Attrs().Get("power_iters"); - auto dim_valid = dim == 0 || dim == 1; - PADDLE_ENFORCE_EQ( - dim_valid, - true, - platform::errors::InvalidArgument( - "Attr(dim) can only be 0 or 1, but received %d", dim)); - PADDLE_ENFORCE_GE( - power_iters, - 0, - platform::errors::InvalidArgument( - "Attr(power_iters) should be greater equal then 0, but received %d", - power_iters)); - - int h = dim_weight[dim]; - int w = 1; - for (int i = 0; i < rank_weight; i++) { - if (i != dim) { - w *= dim_weight[i]; - } - } - auto dim_u = ctx->GetInputDim("U"); - auto dim_v = ctx->GetInputDim("V"); - - if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) { - PADDLE_ENFORCE_EQ(dim_u[0], - h, - platform::errors::InvalidArgument( - "Input(U) dimension[0] should be equal to " - "Input(Weight) dimension[Attr(dim)], but received " - "U dimension[0](%d) != Weight dimension[%d](%d)", - dim_u[0], - dim, - h)); - } - - if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) { - PADDLE_ENFORCE_EQ( - dim_v[0], - w, - platform::errors::InvalidArgument( - "Input(V) dimension[0] should be equal to the product of " - "Input(Weight) dimension except dimension[Attr(dim)], but " - "received V dimension[0](%d) != product of Input(Weight) " - "dimension(%d)", - dim_v[0], - w)); - } - - ctx->SetOutputDim("Out", dim_weight); - ctx->ShareLoD("Weight", /*->*/ "Out"); - } - 
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
@@ -217,26 +145,6 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("Weight"), "Input", "Weight", "SpectralNormGrad");
-    OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad");
-    OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "SpectralNormGrad");
-
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput(framework::GradVarName("Out")),
-        true,
-        platform::errors::NotFound("Input(Out@GRAD) should not be null"));
-    auto dim_x = ctx->GetInputDim("Weight");
-    if (ctx->HasOutput(framework::GradVarName("Weight"))) {
-      ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
@@ -248,9 +156,18 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm, SpectralNormInferMetaFunctor,
+                            PD_INFER_META(phi::SpectralNormInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad, SpectralNormGradInferMetaFunctor,
+                            PD_INFER_META(phi::SpectralNormGradInferMeta));
+
 REGISTER_OPERATOR(spectral_norm,
                   ops::SpectralNormOp,
                   ops::SpectralNormOpMaker,
                   ops::SpectralNormGradOpMaker<paddle::framework::OpDesc>,
-                  ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
+                  ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>,
+                  SpectralNormInferMetaFunctor);
+REGISTER_OPERATOR(spectral_norm_grad,
+                  ops::SpectralNormOpGrad,
+                  SpectralNormGradInferMetaFunctor);
\ No newline at end of file
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 3480af8db88d3..72c7bed5e3523 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -661,6 +661,21 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
   }
 }
 
+void SpectralNormGradInferMeta(const MetaTensor& weight,
+    const MetaTensor& u,
+    const MetaTensor& v,
+    const MetaTensor& out_grad,
+    int dim,
+    int power_iters,
+    float eps,
+    MetaTensor* weight_grad){
+  auto dim_x = weight.dims();
+  if (weight_grad) {
+    weight_grad->set_dims(dim_x);
+    weight_grad->set_dtype(out_grad.dtype());
+  }
+}
+
 void StackGradInferMeta(const MetaTensor& out_grad,
                         int axis,
                         std::vector<MetaTensor*> x_grad) {
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 88825faa95f7c..76cb8f1bfaf28 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -277,6 +277,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
                                MetaTensor* x_grad,
                                MetaTensor* updates_grad);
 
+void SpectralNormGradInferMeta(const MetaTensor& weight,
+    const MetaTensor& u,
+    const MetaTensor& v,
+    const MetaTensor& out_grad,
+    int dim,
+    int power_iters,
+    float eps,
+    MetaTensor* weight_grad);
+
 void StackGradInferMeta(const MetaTensor& out_grad,
                         int axis,
                         std::vector<MetaTensor*> x_grad);
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 9f65de0f0aa70..b6c32f00b66ec 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -978,6 +978,82 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void SpectralNormInferMeta(const MetaTensor& weight,
+    const MetaTensor& u,
+    const MetaTensor& v,
+    int dim,
+    int power_iters,
+    float eps,
+    MetaTensor* out,
+    MetaConfig config){
+  auto dim_weight = weight.dims();
+  auto rank_weight = dim_weight.size();
+  PADDLE_ENFORCE_GE(rank_weight,
+                    2,
+                    errors::InvalidArgument(
+                        "The rank of Input(Weights) should be greater equal "
+                        "than 2, but received Weight rank(%d)",
+                        rank_weight));
+  PADDLE_ENFORCE_LE(rank_weight,
+                    5,
+                    errors::InvalidArgument(
+                        "The rank of Input(Weights) should be less equal "
+                        "than 5, but received Weight rank(%d)",
+                        rank_weight));
+
+  auto dim_valid = dim == 0 || dim == 1;
+  PADDLE_ENFORCE_EQ(
+      dim_valid,
+      true,
+      errors::InvalidArgument(
+          "Attr(dim) can only be 0 or 1, but received %d", dim));
+  PADDLE_ENFORCE_GE(
+      power_iters,
+      0,
+      errors::InvalidArgument(
+          "Attr(power_iters) should be greater equal than 0, but received %d",
+          power_iters));
+
+  int h = dim_weight[dim];
+  int w = 1;
+  for (int i = 0; i < rank_weight; i++) {
+    if (i != dim) {
+      w *= dim_weight[i];
+    }
+  }
+  auto dim_u = u.dims();
+  auto dim_v = v.dims();
+
+  if (config.is_runtime || (dim_u[0] > 0 && h > 0)) {
+    PADDLE_ENFORCE_EQ(dim_u[0],
+                      h,
+                      errors::InvalidArgument(
+                          "Input(U) dimension[0] should be equal to "
+                          "Input(Weight) dimension[Attr(dim)], but received "
+                          "U dimension[0](%d) != Weight dimension[%d](%d)",
+                          dim_u[0],
+                          dim,
+                          h));
+  }
+
+  if (config.is_runtime || (dim_v[0] > 0 && w > 0)) {
+    PADDLE_ENFORCE_EQ(
+        dim_v[0],
+        w,
+        errors::InvalidArgument(
+            "Input(V) dimension[0] should be equal to the product of "
+            "Input(Weight) dimension except dimension[Attr(dim)], but "
+            "received V dimension[0](%d) != product of Input(Weight) "
+            "dimension(%d)",
+            dim_v[0],
+            w));
+  }
+
+  out->set_dims(dim_weight);
+  out->set_dtype(weight.dtype());
+  out->share_lod(weight);
+}
+
 void ViterbiDecodeInferMeta(const MetaTensor& input,
                             const MetaTensor& transition,
                             const MetaTensor& length,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 40461d299fb01..7fd1ef81d916b 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -160,6 +160,15 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
                            const MetaTensor& updates,
                            MetaTensor* out);
 
+void SpectralNormInferMeta(const MetaTensor& weight,
+    const MetaTensor& u,
+    const MetaTensor& v,
+    int dim,
+    int power_iters,
+    float eps,
+    MetaTensor* out,
+    MetaConfig config = MetaConfig());
+
 void ViterbiDecodeInferMeta(const MetaTensor& input,
                             const MetaTensor& transition,
                             const MetaTensor& length,

From 3de6db945c12c1570a6ee459382bfe86ba51c13c Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Mon, 25 Jul 2022 15:05:14 +0800
Subject: [PATCH 10/13] Add yaml files

---
 paddle/phi/api/yaml/legacy_api.yaml      | 10 ++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml | 10 ++++++++++
 2 files changed, 20 insertions(+)

diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index ad93a7c6072e7..9dff8614a2991 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -2038,6 +2038,16 @@
     use_gpudnn : true
   backward : softmax_grad
 
+- api : spectral_norm
+  args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps)
+  output : Tensor
+  infer_meta :
+    func : SpectralNormInferMeta
+  kernel :
+    func : spectral_norm
+    data_type : weight
+  backward : spectral_norm_grad
+
 - api : split
   args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
   output : Tensor[]
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 61eeec6c848bb..c81575cfbbf41 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1963,6 +1963,16 @@
   invoke : concat( out_grad, axis)
 # TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
 
+- backward_api : spectral_norm_grad
+  forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
+  args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps)
+  output : Tensor(weight_grad)
+  infer_meta :
+    func : SpectralNormGradInferMeta
+  kernel :
+    func : spectral_norm_grad
+    data_type : out_grad
+
 - backward_api : sqrt_double_grad
   forward : sqrt_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
   args : (Tensor out, Tensor grad_x, Tensor grad_x_grad)

From a478a37d3ef48a387f3a46b4b00e95542420881b Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Mon, 25 Jul 2022 16:04:43 +0800
Subject: [PATCH 11/13] Add blank line

---
 paddle/fluid/operators/spectral_norm_op.cc     | 12 +++++++-----
 paddle/phi/api/yaml/legacy_api.yaml            |  2 +-
 paddle/phi/api/yaml/legacy_backward.yaml       | 14 +++++++-------
 .../kernels/cpu/spectral_norm_grad_kernel.cc   | 14 +++++++-------
 paddle/phi/kernels/cpu/spectral_norm_kernel.cc | 12 ++++--------
 .../kernels/gpu/spectral_norm_grad_kernel.cu   | 14 +++++++-------
 paddle/phi/kernels/gpu/spectral_norm_kernel.cu | 12 ++++--------
 paddle/phi/kernels/spectral_norm_grad_kernel.h | 18 +++++++++---------
 paddle/phi/kernels/spectral_norm_kernel.h      |  2 +-
 9 files changed, 47 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
index 60f01cc56a13d..1d47a10d56bc2 100644
--- a/paddle/fluid/operators/spectral_norm_op.cc
+++ b/paddle/fluid/operators/spectral_norm_op.cc
@@ -157,9 +157,11 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 
-DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm, SpectralNormInferMetaFunctor,
+DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm,
+                            SpectralNormInferMetaFunctor,
                             PD_INFER_META(phi::SpectralNormInferMeta));
-DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad, SpectralNormGradInferMetaFunctor,
+DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad,
+                            SpectralNormGradInferMetaFunctor,
                             PD_INFER_META(phi::SpectralNormGradInferMeta));
 
 REGISTER_OPERATOR(spectral_norm,
@@ -168,6 +170,6 @@ REGISTER_OPERATOR(spectral_norm,
                   ops::SpectralNormGradOpMaker<paddle::framework::OpDesc>,
                   ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>,
                   SpectralNormInferMetaFunctor);
-REGISTER_OPERATOR(spectral_norm_grad,
-                  ops::SpectralNormOpGrad,
-                  SpectralNormGradInferMetaFunctor);
\ No newline at end of file
+REGISTER_OPERATOR(spectral_norm_grad,
+                  ops::SpectralNormOpGrad,
+                  SpectralNormGradInferMetaFunctor);
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 9dff8614a2991..5b0497284db2c 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -2046,7 +2046,7 @@
   kernel :
     func : spectral_norm
     data_type : weight
-  backward : spectral_norm_grad 
+  backward : spectral_norm_grad
 
 - api : split
   args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index c81575cfbbf41..beb8b6bc2e055 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1956,13 +1956,6 @@
     func : softmax_grad
     use_gpudnn : true
 
-- backward_api : split_grad
-  forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out)
-  args : (Tensor[] out_grad, Scalar axis = -1)
-  output : Tensor(x_grad)
-  invoke : concat( out_grad, axis)
-# TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
-
 - backward_api : spectral_norm_grad
   forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
   args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps)
   output : Tensor(weight_grad)
@@ -1973,6 +1966,13 @@
     func : spectral_norm_grad
     data_type : out_grad
 
+- backward_api : split_grad
+  forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out)
+  args : (Tensor[] out_grad, Scalar axis = -1)
+  output : Tensor(x_grad)
+  invoke : concat( out_grad, axis)
+# TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
+
 - backward_api : sqrt_double_grad
   forward : sqrt_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
   args : (Tensor out, Tensor grad_x, Tensor grad_x_grad)
diff --git a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
index 905f10d780b13..5603eb0f7cb91 100644
--- a/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/spectral_norm_grad_kernel.cc
@@ -12,15 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/phi/kernels/spectral_norm_grad_kernel.h"
 #include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h"
+#include "paddle/phi/kernels/spectral_norm_grad_kernel.h"
 
 PD_REGISTER_KERNEL(spectral_norm_grad,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::SpectralNormGradKernel,
-                   float,
-                   double) {}
\ No newline at end of file
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SpectralNormGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc
index b9d496a55e797..6ff25365d1c16 100644
--- a/paddle/phi/kernels/cpu/spectral_norm_kernel.cc
+++ b/paddle/phi/kernels/cpu/spectral_norm_kernel.cc
@@ -12,15 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/phi/kernels/spectral_norm_kernel.h"
 #include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h"
+#include "paddle/phi/kernels/spectral_norm_kernel.h"
 
-PD_REGISTER_KERNEL(spectral_norm,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::SpectralNormKernel,
-                   float,
-                   double) {}
\ No newline at end of file
+PD_REGISTER_KERNEL(
+    spectral_norm, CPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
index d4b747eb39f45..75c82e90fc059 100644
--- a/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/spectral_norm_grad_kernel.cu
@@ -12,15 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" #include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_grad_kernel.h" PD_REGISTER_KERNEL(spectral_norm_grad, - GPU, - ALL_LAYOUT, - phi::SpectralNormGradKernel, - float, - double) {} \ No newline at end of file + GPU, + ALL_LAYOUT, + phi::SpectralNormGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu index 4a6223ff3570d..7709cf5da1b5a 100644 --- a/paddle/phi/kernels/gpu/spectral_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/spectral_norm_kernel.cu @@ -12,15 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/spectral_norm_kernel.h" #include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h" +#include "paddle/phi/kernels/spectral_norm_kernel.h" -PD_REGISTER_KERNEL(spectral_norm, - GPU, - ALL_LAYOUT, - phi::SpectralNormKernel, - float, - double) {} \ No newline at end of file +PD_REGISTER_KERNEL( + spectral_norm, GPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {} diff --git a/paddle/phi/kernels/spectral_norm_grad_kernel.h b/paddle/phi/kernels/spectral_norm_grad_kernel.h index 783633fd5cb02..504cfba4b95e7 100644 --- a/paddle/phi/kernels/spectral_norm_grad_kernel.h +++ b/paddle/phi/kernels/spectral_norm_grad_kernel.h @@ -17,13 +17,13 @@ namespace phi { template void SpectrumNormGradKernel(const Context& dev_ctx, - const DenseTensor& weight, - const DenseTensor& u, - const DenseTensor& v, - const DenseTensor& out_grad, - int dim, - int power_iters, - float eps, - DenseTensor* weight_grad); + const DenseTensor& weight, + const DenseTensor& u, + const DenseTensor& v, + const DenseTensor& out_grad, + int dim, + int power_iters, + float eps, + DenseTensor* weight_grad); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/spectral_norm_kernel.h b/paddle/phi/kernels/spectral_norm_kernel.h index 89083b994e6e8..26b1699898ea6 100644 --- a/paddle/phi/kernels/spectral_norm_kernel.h +++ b/paddle/phi/kernels/spectral_norm_kernel.h @@ -25,4 +25,4 @@ void SpectrumNormKernel(const Context& dev_ctx, float eps, DenseTensor* out); -} // namespace phi \ No newline at end of file +} // namespace phi From a75de872186ad1826ec3fe6a48d9e94e5a04478d Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Tue, 26 Jul 2022 15:01:37 +0800 Subject: [PATCH 12/13] Fix code style --- paddle/phi/infermeta/backward.cc | 14 +- paddle/phi/infermeta/backward.h | 16 +- paddle/phi/infermeta/ternary.cc | 35 ++-- paddle/phi/infermeta/ternary.h | 16 +- paddle/phi/kernels/funcs/spectral_norm.h | 27 ++- .../impl/spectral_norm_grad_kernel_impl.h | 189 +++++++++--------- .../kernels/impl/spectral_norm_kernel_impl.h | 116 ++++++----- paddle/phi/ops/compat/spectral_norm_sig.cc | 21 +- 8 files changed, 220 insertions(+), 214 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 72c7bed5e3523..fbc2e2b73433b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -662,13 +662,13 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } void 
 void SpectralNormGradInferMeta(const MetaTensor& weight,
-    const MetaTensor& u,
-    const MetaTensor& v,
-    const MetaTensor& out_grad,
-    int dim,
-    int power_iters,
-    float eps,
-    MetaTensor* weight_grad){
+                               const MetaTensor& u,
+                               const MetaTensor& v,
+                               const MetaTensor& out_grad,
+                               int dim,
+                               int power_iters,
+                               float eps,
+                               MetaTensor* weight_grad) {
   auto dim_x = weight.dims();
   if (weight_grad) {
     weight_grad->set_dims(dim_x);
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 76cb8f1bfaf28..4b897ba410014 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -278,14 +278,14 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
                                MetaTensor* updates_grad);
 
 void SpectralNormGradInferMeta(const MetaTensor& weight,
-    const MetaTensor& u,
-    const MetaTensor& v,
-    const MetaTensor& out_grad,
-    int dim,
-    int power_iters,
-    float eps,
-    MetaTensor* weight_grad);
-
+                               const MetaTensor& u,
+                               const MetaTensor& v,
+                               const MetaTensor& out_grad,
+                               int dim,
+                               int power_iters,
+                               float eps,
+                               MetaTensor* weight_grad);
+
 void StackGradInferMeta(const MetaTensor& out_grad,
                         int axis,
                         std::vector<MetaTensor*> x_grad);
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index b6c32f00b66ec..3b7b655bd2ee5 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -979,13 +979,13 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
 }
 
 void SpectralNormInferMeta(const MetaTensor& weight,
-    const MetaTensor& u,
-    const MetaTensor& v,
-    int dim,
-    int power_iters,
-    float eps,
-    MetaTensor* out,
-    MetaConfig config){
+                           const MetaTensor& u,
+                           const MetaTensor& v,
+                           int dim,
+                           int power_iters,
+                           float eps,
+                           MetaTensor* out,
+                           MetaConfig config) {
   auto dim_weight = weight.dims();
   auto rank_weight = dim_weight.size();
   PADDLE_ENFORCE_GE(rank_weight,
@@ -994,19 +994,18 @@ void SpectralNormInferMeta(const MetaTensor& weight,
                         "The rank of Input(Weights) should be greater equal "
                         "than 2, but received Weight rank(%d)",
                         rank_weight));
-  PADDLE_ENFORCE_LE(rank_weight,
-                    5,
-                    errors::InvalidArgument(
-                        "The rank of Input(Weights) should be less equal "
-                        "than 5, but received Weight rank(%d)",
-                        rank_weight));
+  PADDLE_ENFORCE_LE(
+      rank_weight,
+      5,
+      errors::InvalidArgument("The rank of Input(Weights) should be less equal "
+                              "than 5, but received Weight rank(%d)",
+                              rank_weight));
 
   auto dim_valid = dim == 0 || dim == 1;
-  PADDLE_ENFORCE_EQ(
-      dim_valid,
-      true,
-      errors::InvalidArgument(
-          "Attr(dim) can only be 0 or 1, but received %d", dim));
+  PADDLE_ENFORCE_EQ(dim_valid,
+                    true,
+                    errors::InvalidArgument(
+                        "Attr(dim) can only be 0 or 1, but received %d", dim));
   PADDLE_ENFORCE_GE(
       power_iters,
       0,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index 7fd1ef81d916b..b6823da959d4a 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -161,14 +161,14 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
                            MetaTensor* out);
 
 void SpectralNormInferMeta(const MetaTensor& weight,
-    const MetaTensor& u,
-    const MetaTensor& v,
-    int dim,
-    int power_iters,
-    float eps,
-    MetaTensor* out,
-    MetaConfig config = MetaConfig());
-
+                           const MetaTensor& u,
+                           const MetaTensor& v,
+                           int dim,
+                           int power_iters,
+                           float eps,
+                           MetaTensor* out,
+                           MetaConfig config = MetaConfig());
+
 void ViterbiDecodeInferMeta(const MetaTensor& input,
                             const MetaTensor& transition,
                             const MetaTensor& length,
diff --git a/paddle/phi/kernels/funcs/spectral_norm.h b/paddle/phi/kernels/funcs/spectral_norm.h
index 41c27e87ab0e4..290fbf59e2939 100644
--- a/paddle/phi/kernels/funcs/spectral_norm.h
+++ b/paddle/phi/kernels/funcs/spectral_norm.h
@@ -14,8 +14,8 @@
 
 #pragma once
 
-#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -26,12 +26,12 @@ using IndexPair = Eigen::IndexPair<int>;
 
 template <typename Context, typename T>
 static inline void TransCompute2DTo5D(const Context& dev_ctx,
-    const DenseTensor& in,
-    const int rank,
-    const std::vector<int>& perm,
-    DenseTensor* out) {
+                                      const DenseTensor& in,
+                                      const int rank,
+                                      const std::vector<int>& perm,
+                                      DenseTensor* out) {
   if (rank <= 1 || rank > 5) {
-    PADDLE_THROW(errors::Fatal(
+    PADDLE_THROW(phi::errors::Fatal(
         "Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
         rank));
   }
@@ -59,14 +59,13 @@ static inline void TransCompute2DTo5D(const Context& dev_ctx,
 }
 
 template <typename Context, typename T>
-static inline void CalcMatrixSigmaAndNormWeight(
-    const Context& dev_ctx,
-    DenseTensor* weight,
-    DenseTensor* u,
-    DenseTensor* v,
-    DenseTensor* sigma,
-    const int power_iters,
-    const float eps) {
+static inline void CalcMatrixSigmaAndNormWeight(const Context& dev_ctx,
+                                                DenseTensor* weight,
+                                                DenseTensor* u,
+                                                DenseTensor* v,
+                                                DenseTensor* sigma,
+                                                const int power_iters,
+                                                const float eps) {
   auto& place = *dev_ctx.eigen_device();
   auto blas = funcs::GetBlas<Context, T>(dev_ctx);
   auto sigma_t = EigenTensor<T, 2>::From(*sigma);
diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
index 808611be72572..76282ae4b9410 100644
--- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
@@ -20,106 +20,111 @@ namespace phi {
 
 template <typename T, typename Context>
 void SpectralNormGradKernel(const Context& dev_ctx,
-    const DenseTensor& weight,
-    const DenseTensor& u,
-    const DenseTensor& v,
-    const DenseTensor& out_grad,
-    int dim,
-    int power_iters,
-    float eps,
-    DenseTensor* weight_grad){
-    auto& place = *dev_ctx.eigen_device();
-    auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+                            const DenseTensor& weight,
+                            const DenseTensor& u,
+                            const DenseTensor& v,
+                            const DenseTensor& out_grad,
+                            int dim,
+                            int power_iters,
+                            float eps,
+                            DenseTensor* weight_grad) {
+  auto& place = *dev_ctx.eigen_device();
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
 
-    const int h = u.dims()[0];
-    const int w = v.dims()[0];
+  const int h = u.dims()[0];
+  const int w = v.dims()[0];
 
-    DenseTensor weight_mat, out_grad_mat;
-    auto dims = weight.dims();
-    const int rank = dims.size();
-    std::vector<int> real_dims;
-    if (dim != 0) {
-      std::vector<int> perm;
-      perm.push_back(dim);
-      real_dims.push_back(dims[dim]);
-      for (int i = 0; i < rank; i++) {
-        if (i != dim) {
-          perm.push_back(i);
-          real_dims.push_back(dims[i]);
-        }
-      }
-      weight_mat.Resize(phi::make_ddim(real_dims));
-      dev_ctx.template Alloc<T>(&weight_mat);
-      out_grad_mat.Resize(phi::make_ddim(real_dims));
-      dev_ctx.template Alloc<T>(&out_grad_mat);
-      TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
-      TransCompute2DTo5D<Context, T>(dev_ctx, out_grad, rank, perm, &out_grad_mat);
-    } else {
-      for (int i = 0; i < rank; i++) {
-        real_dims.push_back(i);
-      }
-      phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
-      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), true, &out_grad_mat);
-    }
-    weight_mat = weight_mat.Resize({h, w});
-    out_grad_mat = out_grad_mat.Resize({h, w});
+  DenseTensor weight_mat, out_grad_mat;
+  auto dims = weight.dims();
+  const int rank = dims.size();
+  std::vector<int> real_dims;
+  if (dim != 0) {
+    std::vector<int> perm;
+    perm.push_back(dim);
+    real_dims.push_back(dims[dim]);
+    for (int i = 0; i < rank; i++) {
+      if (i != dim) {
+        perm.push_back(i);
+        real_dims.push_back(dims[i]);
+      }
+    }
+    weight_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&weight_mat);
+    out_grad_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&out_grad_mat);
+    TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
+    TransCompute2DTo5D<Context, T>(
+        dev_ctx, out_grad, rank, perm, &out_grad_mat);
+  } else {
+    for (int i = 0; i < rank; i++) {
+      real_dims.push_back(i);
+    }
+    phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
+    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), true, &out_grad_mat);
+  }
+  weight_mat = weight_mat.Resize({h, w});
+  out_grad_mat = out_grad_mat.Resize({h, w});
 
-    DenseTensor sigma;
-    sigma.Resize(weight_mat.dims());
-    dev_ctx.template Alloc<T>(&sigma);
-    DenseTensor uu, vv;
-    phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
-    phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
-    CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
-                                             &weight_mat,
-                                             &(uu.Resize({h, 1})),
-                                             &(vv.Resize({w, 1})),
-                                             &sigma,
-                                             power_iters,
-                                             eps);
+  DenseTensor sigma;
+  sigma.Resize(weight_mat.dims());
+  dev_ctx.template Alloc<T>(&sigma);
+  DenseTensor uu, vv;
+  phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
+  phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
+  CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
+                                           &weight_mat,
+                                           &(uu.Resize({h, 1})),
+                                           &(vv.Resize({w, 1})),
+                                           &sigma,
+                                           power_iters,
+                                           eps);
 
-    DenseTensor uv;
-    uv.Resize({h, w});
-    dev_ctx.template Alloc<T>(&uv);
-    blas.MatMul(
-        uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
+  DenseTensor uv;
+  uv.Resize({h, w});
+  dev_ctx.template Alloc<T>(&uv);
+  blas.MatMul(
+      uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
 
-    DenseTensor weight_grad_mat;
-    weight_grad_mat.Resize({h, w});
-    dev_ctx.template Alloc<T>(&weight_grad_mat);
-    auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
-    auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
-    auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
-    auto sigma_t = EigenTensor<T, 2>::From(sigma);
-    auto uv_t = EigenTensor<T, 2>::From(uv);
-    weight_mat_t.device(place) =
-        weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
-    weight_grad_mat_t.device(place) =
-        out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
-        sigma_t;
+  DenseTensor weight_grad_mat;
+  weight_grad_mat.Resize({h, w});
+  dev_ctx.template Alloc<T>(&weight_grad_mat);
+  auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
+  auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
+  auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
+  auto sigma_t = EigenTensor<T, 2>::From(sigma);
+  auto uv_t = EigenTensor<T, 2>::From(uv);
+  weight_mat_t.device(place) =
+      weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
+  weight_grad_mat_t.device(place) =
+      out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
+      sigma_t;
 
-    if (dim != 0) {
-      std::vector<int> perm;
-      for (int i = 0; i < rank; i++) {
-        if (i < dim) {
-          perm.push_back(i + 1);
-        } else if (i == dim) {
-          perm.push_back(0);
-        } else {
-          perm.push_back(i);
-        }
-      }
-      weight_grad->Resize(dims);
-      dev_ctx.template Alloc<T>(weight_grad);
-      TransCompute2DTo5D<Context, T>(
-          dev_ctx,
-          weight_grad_mat.Resize(phi::make_ddim(real_dims)),
-          rank,
-          perm,
-          weight_grad);
-    } else {
-      phi::Copy(dev_ctx, weight_grad_mat.Resize(dims), dev_ctx.GetPlace(), true, weight_grad);
-    }
+  if (dim != 0) {
+    std::vector<int> perm;
+    for (int i = 0; i < rank; i++) {
+      if (i < dim) {
+        perm.push_back(i + 1);
+      } else if (i == dim) {
+        perm.push_back(0);
+      } else {
+        perm.push_back(i);
+      }
+    }
+    weight_grad->Resize(dims);
+    dev_ctx.template Alloc<T>(weight_grad);
+    TransCompute2DTo5D<Context, T>(
+        dev_ctx,
+        weight_grad_mat.Resize(phi::make_ddim(real_dims)),
+        rank,
+        perm,
+        weight_grad);
+  } else {
+    phi::Copy(dev_ctx,
+              weight_grad_mat.Resize(dims),
+              dev_ctx.GetPlace(),
+              true,
+              weight_grad);
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
index 99ff414d0f35e..70e13bb543515 100644
--- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
@@ -26,71 +26,67 @@ void SpectralNormKernel(const Context& dev_ctx,
                         int dim,
                         int power_iters,
                         float eps,
-                        DenseTensor* out){
-    const int h = u.dims()[0];
-    const int w = v.dims()[0];
+                        DenseTensor* out) {
+  const int h = u.dims()[0];
+  const int w = v.dims()[0];
 
-    DenseTensor weight_mat;
-    auto dims = weight.dims();
-    const int rank = dims.size();
-    std::vector<int> real_dims;
-    if (dim != 0) {
-      std::vector<int> perm;
-      perm.push_back(dim);
-      real_dims.push_back(dims[dim]);
-      for (int i = 0; i < rank; i++) {
-        if (i != dim) {
-          perm.push_back(i);
-          real_dims.push_back(dims[i]);
-        }
-      }
-      weight_mat.Resize(phi::make_ddim(real_dims));
-      dev_ctx.template Alloc<T>(&weight_mat);
-      TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
-    } else {
-      for (int i = 0; i < rank; i++) {
-        real_dims.push_back(i);
-      }
-      phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
-    }
-    weight_mat = weight_mat.Resize({h, w});
+  DenseTensor weight_mat;
+  auto dims = weight.dims();
+  const int rank = dims.size();
+  std::vector<int> real_dims;
+  if (dim != 0) {
+    std::vector<int> perm;
+    perm.push_back(dim);
+    real_dims.push_back(dims[dim]);
+    for (int i = 0; i < rank; i++) {
+      if (i != dim) {
+        perm.push_back(i);
+        real_dims.push_back(dims[i]);
+      }
+    }
+    weight_mat.Resize(phi::make_ddim(real_dims));
+    dev_ctx.template Alloc<T>(&weight_mat);
+    TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
+  } else {
+    for (int i = 0; i < rank; i++) {
+      real_dims.push_back(i);
+    }
+    phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
+  }
+  weight_mat = weight_mat.Resize({h, w});
 
-    DenseTensor sigma;
-    sigma.Resize(weight_mat.dims());
-    dev_ctx.template Alloc<T>(&sigma);
-    DenseTensor uu, vv;
-    phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
-    phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
-    CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
-                                             &weight_mat,
-                                             &(uu.Resize({h, 1})),
-                                             &(vv.Resize({w, 1})),
-                                             &sigma,
-                                             power_iters,
-                                             eps);
+  DenseTensor sigma;
+  sigma.Resize(weight_mat.dims());
+  dev_ctx.template Alloc<T>(&sigma);
+  DenseTensor uu, vv;
+  phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
+  phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
+  CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
+                                           &weight_mat,
+                                           &(uu.Resize({h, 1})),
+                                           &(vv.Resize({w, 1})),
+                                           &sigma,
+                                           power_iters,
+                                           eps);
 
-    if (dim != 0) {
-      std::vector<int> perm;
-      for (int i = 0; i < rank; i++) {
-        if (i < dim) {
-          perm.push_back(i + 1);
-        } else if (i == dim) {
-          perm.push_back(0);
-        } else {
-          perm.push_back(i);
-        }
-      }
-      out->Resize(dims);
-      dev_ctx.template Alloc<T>(out);
-      TransCompute2DTo5D<Context, T>(
-          dev_ctx,
-          weight_mat.Resize(phi::make_ddim(real_dims)),
-          rank,
-          perm,
-          out);
-    } else {
-      phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), true, out);
-    }
+  if (dim != 0) {
+    std::vector<int> perm;
+    for (int i = 0; i < rank; i++) {
+      if (i < dim) {
+        perm.push_back(i + 1);
+      } else if (i == dim) {
+        perm.push_back(0);
+      } else {
+        perm.push_back(i);
+      }
+    }
+    out->Resize(dims);
+    dev_ctx.template Alloc<T>(out);
+    TransCompute2DTo5D<Context, T>(
+        dev_ctx, weight_mat.Resize(phi::make_ddim(real_dims)), rank, perm, out);
+  } else {
+    phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), true, out);
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/spectral_norm_sig.cc b/paddle/phi/ops/compat/spectral_norm_sig.cc
index 16705ba76390d..ea11df24881aa 100644
--- a/paddle/phi/ops/compat/spectral_norm_sig.cc
+++ b/paddle/phi/ops/compat/spectral_norm_sig.cc
@@ -14,19 +14,26 @@
 
 #include "paddle/phi/core/compat/op_utils.h"
 
-namespace phi{
+namespace phi {
 
-KernelSignature SpectralNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("spectral_norm", {"Weight", "U", "V"}, {"dim", "power_iters", "eps"}, {"Out"});
+KernelSignature SpectralNormOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("spectral_norm",
+                         {"Weight", "U", "V"},
+                         {"dim", "power_iters", "eps"},
+                         {"Out"});
 }
 
 KernelSignature SpectralNormGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "spectral_norm_grad", {"Weight", "U", "V", "Out@GRAD"}, {"dim", "power_iters", "eps"}, {"Weight@GRAD"});
+  return KernelSignature("spectral_norm_grad",
+                         {"Weight", "U", "V", "Out@GRAD"},
+                         {"dim", "power_iters", "eps"},
+                         {"Weight@GRAD"});
 }
 
-} // namespace phi
+}  // namespace phi
 
 PD_REGISTER_ARG_MAPPING_FN(spectral_norm, phi::SpectralNormOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad, phi::SpectralNormGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad,
+                           phi::SpectralNormGradOpArgumentMapping);

From 7271d14182e1aa520f624a53f1ace309972e0936 Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Wed, 27 Jul 2022 11:10:48 +0800
Subject: [PATCH 13/13] Optimize directory structure

---
 paddle/phi/infermeta/ternary.cc               |   8 +-
 paddle/phi/kernels/funcs/spectral_norm.h      | 106 ------------------
 .../impl/spectral_norm_grad_kernel_impl.h     |   2 +-
 .../kernels/impl/spectral_norm_kernel_impl.h  |  87 +++++++++++++-
 4 files changed, 92 insertions(+), 111 deletions(-)
 delete mode 100644 paddle/phi/kernels/funcs/spectral_norm.h

diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 3b7b655bd2ee5..704f14dcbce9a 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -1048,9 +1048,11 @@ void SpectralNormInferMeta(const MetaTensor& weight,
             w));
   }
 
-  out->set_dims(dim_weight);
-  out->set_dtype(weight.dtype());
-  out->share_lod(weight);
+  if (out) {
+    out->set_dims(dim_weight);
+    out->set_dtype(weight.dtype());
+    out->share_lod(weight);
+  }
 }
 
 void ViterbiDecodeInferMeta(const MetaTensor& input,
diff --git a/paddle/phi/kernels/funcs/spectral_norm.h b/paddle/phi/kernels/funcs/spectral_norm.h
deleted file mode 100644
index 290fbf59e2939..0000000000000
--- a/paddle/phi/kernels/funcs/spectral_norm.h
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace phi {
-
-using Array1 = Eigen::DSizes<int64_t, 1>;
-using Array2 = Eigen::DSizes<int64_t, 2>;
-using IndexPair = Eigen::IndexPair<int>;
-
-template <typename Context, typename T>
-static inline void TransCompute2DTo5D(const Context& dev_ctx,
-                                      const DenseTensor& in,
-                                      const int rank,
-                                      const std::vector<int>& perm,
-                                      DenseTensor* out) {
-  if (rank <= 1 || rank > 5) {
-    PADDLE_THROW(phi::errors::Fatal(
-        "Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
-        rank));
-  }
-
-  switch (rank) {
-    case 2:
-      phi::funcs::Transpose<Context, T, 2> trans2;
-      trans2(dev_ctx, in, out, perm);
-      break;
-    case 3:
-      phi::funcs::Transpose<Context, T, 3> trans3;
-      trans3(dev_ctx, in, out, perm);
-      break;
-    case 4:
-      phi::funcs::Transpose<Context, T, 4> trans4;
-      trans4(dev_ctx, in, out, perm);
-      break;
-    case 5:
-      phi::funcs::Transpose<Context, T, 5> trans5;
-      trans5(dev_ctx, in, out, perm);
-      break;
-    default:
-      break;
-  }
-}
-
-template <typename Context, typename T>
-static inline void CalcMatrixSigmaAndNormWeight(const Context& dev_ctx,
-                                                DenseTensor* weight,
-                                                DenseTensor* u,
-                                                DenseTensor* v,
-                                                DenseTensor* sigma,
-                                                const int power_iters,
-                                                const float eps) {
-  auto& place = *dev_ctx.eigen_device();
-  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
-  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
-  auto weight_t = EigenTensor<T, 2>::From(*weight);
-  auto u_t = EigenTensor<T, 2>::From(*u);
-  auto v_t = EigenTensor<T, 2>::From(*v);
-
-  const int h = weight->dims()[0];
-  const int w = weight->dims()[1];
-
-  for (int i = 0; i < power_iters; i++) {
-    // V = W^T * U / ||W^T * U||_2
-    blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
-    auto v_t_norm =
-        v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
-            Array1(w));
-    v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
-    // U = W * V / ||W * V||_2
-    blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
-    auto u_t_norm =
-        u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
-            Array1(h));
-    u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
-  }
-  DenseTensor weight_v;
-  weight_v.Resize({h, 1});
-  dev_ctx.template Alloc<T>(&weight_v);
-  blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
-  auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
-  sigma_t.device(place) = (u_t * weight_v_t)
-                              .sum()
-                              .eval()
-                              .reshape(Array2(1, 1))
-                              .broadcast(Array2(h, w));
-  weight_t.device(place) = weight_t / sigma_t;
-}
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
index 76282ae4b9410..5bdb874bc89c4 100644
--- a/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/phi/kernels/funcs/spectral_norm.h"
+#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h"
 
 namespace phi {
 
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
index 70e13bb543515..57c5c69a63d61 100644
--- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
@@ -14,10 +14,95 @@
 
 #pragma once
 
-#include "paddle/phi/kernels/funcs/spectral_norm.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
+using Array1 = Eigen::DSizes<int64_t, 1>;
+using Array2 = Eigen::DSizes<int64_t, 2>;
+using IndexPair = Eigen::IndexPair<int>;
+
+template <typename Context, typename T>
+static inline void TransCompute2DTo5D(const Context& dev_ctx,
+                                      const DenseTensor& in,
+                                      const int rank,
+                                      const std::vector<int>& perm,
+                                      DenseTensor* out) {
+  if (rank <= 1 || rank > 5) {
+    PADDLE_THROW(phi::errors::Fatal(
+        "Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
+        rank));
+  }
+
+  switch (rank) {
+    case 2:
+      phi::funcs::Transpose<Context, T, 2> trans2;
+      trans2(dev_ctx, in, out, perm);
+      break;
+    case 3:
+      phi::funcs::Transpose<Context, T, 3> trans3;
+      trans3(dev_ctx, in, out, perm);
+      break;
+    case 4:
+      phi::funcs::Transpose<Context, T, 4> trans4;
+      trans4(dev_ctx, in, out, perm);
+      break;
+    case 5:
+      phi::funcs::Transpose<Context, T, 5> trans5;
+      trans5(dev_ctx, in, out, perm);
+      break;
+    default:
+      break;
+  }
+}
+
+template <typename Context, typename T>
+static inline void CalcMatrixSigmaAndNormWeight(const Context& dev_ctx,
+                                                DenseTensor* weight,
+                                                DenseTensor* u,
+                                                DenseTensor* v,
+                                                DenseTensor* sigma,
+                                                const int power_iters,
+                                                const float eps) {
+  auto& place = *dev_ctx.eigen_device();
+  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
+  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
+  auto weight_t = EigenTensor<T, 2>::From(*weight);
+  auto u_t = EigenTensor<T, 2>::From(*u);
+  auto v_t = EigenTensor<T, 2>::From(*v);
+
+  const int h = weight->dims()[0];
+  const int w = weight->dims()[1];
+
+  for (int i = 0; i < power_iters; i++) {
+    // V = W^T * U / ||W^T * U||_2
+    blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
+    auto v_t_norm =
+        v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(w));
+    v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
+    // U = W * V / ||W * V||_2
+    blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
+    auto u_t_norm =
+        u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
+            Array1(h));
+    u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
+  }
+  DenseTensor weight_v;
+  weight_v.Resize({h, 1});
+  dev_ctx.template Alloc<T>(&weight_v);
+  blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
+  auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
+  sigma_t.device(place) = (u_t * weight_v_t)
+                              .sum()
+                              .eval()
+                              .reshape(Array2(1, 1))
+                              .broadcast(Array2(h, w));
+  weight_t.device(place) = weight_t / sigma_t;
+}
+
 template <typename T, typename Context>
 void SpectralNormKernel(const Context& dev_ctx,
                         const DenseTensor& weight,