Skip to content

Commit

Permalink
[VJP]Fix frobenius_norm bug and add to vjp list (PaddlePaddle#57942)
Browse files Browse the repository at this point in the history
* fix frobenius_norm bug

* recover yamls and move frobenius_norm to manual

* fix and close new_ir

* fix frobenius_norm

* clean

* add ReduceInferMetaBaseFN for frobenius_norm

* clean

* fix rename bug

* clean

* use ReduceIntArrayAxisInferMetaBase instead
  • Loading branch information
changeyoung98 authored Oct 23, 2023
1 parent 46d84e5 commit 9d273aa
Show file tree
Hide file tree
Showing 12 changed files with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


vjp_interface_black_list = [
'frobenius_norm',
'write_to_array',
'fused_attention',
'fused_feedforward',
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@
invoke : zeros_like(out_grad)

- backward_op : frobenius_norm_grad
forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all)
forward : frobenius_norm(Tensor x, IntArray axis, bool keep_dim, bool reduce_all) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis, bool keep_dim, bool reduce_all)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,10 @@
inplace: (x -> out)

- op : frobenius_norm
args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
args : (Tensor x, IntArray axis, bool keep_dim, bool reduce_all)
output : Tensor(out)
infer_meta :
func : ReduceInferMetaBase
func : ReduceIntArrayAxisInferMetaBase
kernel :
func : frobenius_norm
backward : frobenius_norm_grad
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@
args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1)
output : Tensor(out)
infer_meta :
func : ReduceInferMetaBase
func : ReduceIntArrayAxisInferMetaBase
kernel :
func : frobenius_norm
param : [x, axis, keepdim, reduce_all]
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3263,6 +3263,7 @@ void ReduceInferMeta(const MetaTensor& x,
if (axis.empty()) {
reduce_all = true;
}

ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out);
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/frobenius_norm_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
Expand All @@ -25,7 +26,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
const std::vector<int64_t>& axis,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* dx);
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/frobenius_norm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void FrobeniusNormKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ namespace phi {
template <typename T, typename Context>
void FrobeniusNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
reduce_all = recompute_reduce_all(x, dims.GetData(), reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);

SqrtKernel<T, Context>(dev_ctx, *out, out);
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ void FrobeniusNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
const std::vector<int64_t>& axis,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* dx) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all);
ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
ctx, x, out, dout, axis.GetData(), keep_dim, reduce_all, dx);
}

} // namespace phi
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ namespace phi {
template <typename T, typename Context>
void FrobeniusNormKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all);
Reduce<Context, T, funcs::FrobeniusNormFunctor>(
ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
ctx, x, reduce_all, axis.GetData(), keep_dim, x.dtype(), out);
}

} // namespace phi
2 changes: 1 addition & 1 deletion python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None):
"The dim of frobenius norm op should be None or two elements list!"
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if dim is None:
return _C_ops.frobenius_norm(input, [], keepdim, True)
return _C_ops.frobenius_norm(input, dim, keepdim, False)
Expand Down
6 changes: 3 additions & 3 deletions test/legacy_test/test_norm_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ def setUp(self):
self.outputs = {'Out': norm}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)

def init_test_case(self):
self.shape = [2, 3, 4, 5]
Expand All @@ -126,7 +126,7 @@ def init_dtype(self):
self.dtype = "float32"

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestPnormOp(OpTest):
Expand Down

0 comments on commit 9d273aa

Please sign in to comment.