[VJP]Fix frobenius_norm bug and add to vjp list #57942

Merged
merged 14 commits into from
Oct 23, 2023
@@ -24,7 +24,6 @@


vjp_interface_black_list = [
-'frobenius_norm',
'write_to_array',
'fused_attention',
'fused_feedforward',
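Removing frobenius_norm from vjp_interface_black_list lets the VJP (reverse-mode gradient) interface be code-generated for the op instead of being skipped. A minimal dygraph sketch of what that enables, assuming a Paddle build that contains this patch (shapes and values below are made up for illustration); analytically, the gradient of ||X||_F with respect to X is X / ||X||_F.

import paddle

x = paddle.rand([2, 3, 4, 5])
x.stop_gradient = False

# paddle.linalg.norm with p='fro' and a two-element axis routes to frobenius_norm.
out = paddle.linalg.norm(x, p='fro', axis=[-2, -1])
(grad,) = paddle.grad(out.sum(), [x])

# Compare against the analytic gradient X / ||X||_F, broadcast over the reduced axes.
expected = x / paddle.linalg.norm(x, p='fro', axis=[-2, -1], keepdim=True)
print(paddle.allclose(grad, expected))  # should print a True tensor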
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -246,8 +246,8 @@
invoke : zeros_like(out_grad)

- backward_op : frobenius_norm_grad
-forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out)
-args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all)
+forward : frobenius_norm(Tensor x, IntArray axis, bool keep_dim, bool reduce_all) -> Tensor(out)
+args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis, bool keep_dim, bool reduce_all)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -452,10 +452,10 @@
inplace: (x -> out)

- op : frobenius_norm
-args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
+args : (Tensor x, IntArray axis, bool keep_dim, bool reduce_all)
output : Tensor(out)
infer_meta :
-func : ReduceInferMetaBase
+func : ReduceIntArrayAxisInferMetaBase
kernel :
func : frobenius_norm
backward : frobenius_norm_grad
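At the op-definition level the change is the type of axis: the fixed int64_t[] attribute becomes an IntArray, phi's wrapper that can hold the axes as literal integers or (in principle) as a Tensor, which is why infer_meta switches to ReduceIntArrayAxisInferMetaBase and the kernels below read the value back with GetData(). The numerical semantics are unchanged; a plain NumPy sketch of what the op computes (reference only, not Paddle code):

import numpy as np

def frobenius_norm_ref(x, axis, keep_dim=False):
    # Frobenius norm over the given axes: sqrt of the sum of squares.
    # reduce_all=True in the op corresponds to reducing over every axis.
    return np.sqrt(np.sum(np.square(x), axis=tuple(axis), keepdims=keep_dim))

x = np.random.rand(2, 3, 4, 5).astype('float32')
print(frobenius_norm_ref(x, axis=[1, 2]).shape)  # (2, 5)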
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/static_ops.yaml
@@ -256,7 +256,7 @@
args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1)
output : Tensor(out)
infer_meta :
-func : ReduceInferMetaBase
+func : ReduceIntArrayAxisInferMetaBase
kernel :
func : frobenius_norm
param : [x, axis, keepdim, reduce_all]
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
@@ -3263,6 +3263,7 @@ void ReduceInferMeta(const MetaTensor& x,
if (axis.empty()) {
reduce_all = true;
}
+
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out);
}

3 changes: 2 additions & 1 deletion paddle/phi/kernels/frobenius_norm_grad_kernel.h
@@ -16,6 +16,7 @@

#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
@@ -25,7 +26,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
-const std::vector<int64_t>& axis,
+const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* dx);
3 changes: 2 additions & 1 deletion paddle/phi/kernels/frobenius_norm_kernel.h
@@ -16,14 +16,15 @@

#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void FrobeniusNormKernel(const Context& ctx,
const DenseTensor& x,
-const std::vector<int64_t>& axis,
+const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
@@ -24,14 +24,14 @@ namespace phi {
template <typename T, typename Context>
void FrobeniusNormKernel(const Context& dev_ctx,
const DenseTensor& x,
-const std::vector<int64_t>& dims,
+const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
-reduce_all = recompute_reduce_all(x, dims, reduce_all);
+reduce_all = recompute_reduce_all(x, dims.GetData(), reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
-dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
+dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);

SqrtKernel<T, Context>(dev_ctx, *out, out);
}
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
@@ -25,13 +25,13 @@ void FrobeniusNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
-const std::vector<int64_t>& axis,
+const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* dx) {
-reduce_all = recompute_reduce_all(x, axis, reduce_all);
+reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all);
ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
-ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
+ctx, x, out, dout, axis.GetData(), keep_dim, reduce_all, dx);
}

} // namespace phi
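The backward kernel still delegates to ReduceGradKernel with funcs::FrobeniusNormGradFunctor; only the axis type changes. For reference, the rule being computed follows from the standard derivative of out = sqrt(sum(x**2)), namely dx = dout * x / out with out and dout broadcast back over the reduced axes. A NumPy sketch of that rule (an illustration, not the kernel code):

import numpy as np

def frobenius_norm_grad_ref(x, out, dout, axis, keep_dim=False):
    # dx = dout * x / out; expand out/dout over the reduced axes when keep_dim is False.
    if not keep_dim:
        out = np.expand_dims(out, tuple(axis))
        dout = np.expand_dims(dout, tuple(axis))
    return dout * x / out

x = np.random.rand(2, 3, 4)
out = np.sqrt(np.square(x).sum(axis=(1, 2)))
dx = frobenius_norm_grad_ref(x, out, np.ones_like(out), axis=(1, 2))
print(np.allclose(dx, x / out[:, None, None]))  # True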
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
@@ -23,13 +23,13 @@ namespace phi {
template <typename T, typename Context>
void FrobeniusNormKernel(const Context& ctx,
const DenseTensor& x,
-const std::vector<int64_t>& axis,
+const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
-reduce_all = recompute_reduce_all(x, axis, reduce_all);
+reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all);
Reduce<Context, T, funcs::FrobeniusNormFunctor>(
-ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
+ctx, x, reduce_all, axis.GetData(), keep_dim, x.dtype(), out);
}

} // namespace phi
2 changes: 1 addition & 1 deletion python/paddle/tensor/linalg.py
@@ -382,7 +382,7 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None):
"The dim of frobenius norm op should be None or two elements list!"
)

-if in_dynamic_mode():
+if in_dynamic_or_pir_mode():
if dim is None:
return _C_ops.frobenius_norm(input, [], keepdim, True)
return _C_ops.frobenius_norm(input, dim, keepdim, False)
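Switching the guard from in_dynamic_mode() to in_dynamic_or_pir_mode() means the Python front end calls the C++ frobenius_norm op not only in eager mode but also when building a PIR static graph. A rough static-graph usage sketch, assuming a build and flag configuration where the PIR pipeline is active (the exact flag names and defaults vary by version):

import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', shape=[2, 3, 4], dtype='float32')
    # A two-element axis with p='fro' lowers to the frobenius_norm op.
    y = paddle.linalg.norm(x, p='fro', axis=[1, 2])

exe = paddle.static.Executor()
(out,) = exe.run(main, feed={'x': np.ones([2, 3, 4], 'float32')}, fetch_list=[y])
print(out)  # each entry is sqrt(3 * 4) ≈ 3.4641
paddle.disable_static()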
6 changes: 3 additions & 3 deletions test/legacy_test/test_norm_all.py
@@ -102,10 +102,10 @@ def setUp(self):
self.outputs = {'Out': norm}

def test_check_output(self):
-self.check_output()
+self.check_output(check_pir=True)

def test_check_grad(self):
-self.check_grad(['X'], 'Out')
+self.check_grad(['X'], 'Out', check_pir=True)

def init_test_case(self):
self.shape = [2, 3, 4, 5]
@@ -126,7 +126,7 @@ def init_dtype(self):
self.dtype = "float32"

def test_check_grad(self):
-self.check_grad(['X'], 'Out')
+self.check_grad(['X'], 'Out', check_pir=True)


class TestPnormOp(OpTest):