
modify sum_as to reduce_as
zeroRains committed Apr 9, 2024
1 parent 009cf48 commit 21157bc
Showing 14 changed files with 83 additions and 83 deletions.
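
For orientation, a minimal NumPy sketch of what the renamed operator computes (function name hypothetical; it mirrors the docstring example in python/paddle/tensor/math.py below, and ignores any size-1 broadcast axes the real kernels may also handle):

import numpy as np

def reduce_as_reference(x: np.ndarray, target: np.ndarray) -> np.ndarray:
    # Sum x over its extra leading axes until its shape matches target's.
    extra = x.ndim - target.ndim
    out = x.sum(axis=tuple(range(extra))) if extra > 0 else x
    assert out.shape == target.shape
    return out

x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
target = np.array([1, 2, 3, 4])
print(reduce_as_reference(x, target))  # [ 6  8 10 12]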
20 changes: 10 additions & 10 deletions paddle/phi/api/yaml/backward.yaml
@@ -1876,6 +1876,16 @@
   func : reciprocal_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : reduce_as_grad
+  forward : reduce_as(Tensor x, Tensor target) -> Tensor(out)
+  args : (Tensor x, Tensor target, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : reduce_as_grad
+
 - backward_op : relu6_grad
   forward : relu6 (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -2407,16 +2417,6 @@
   kernel :
     func : stanh_grad
 
-- backward_op : sum_as_grad
-  forward : sum_as(Tensor x, Tensor target) -> Tensor(out)
-  args : (Tensor x, Tensor target, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : sum_as_grad
-
 - backward_op : svd_grad
   forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
   args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
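The new reduce_as_grad entry wires the backward pass to a kernel that takes x, target, and out_grad and produces x_grad shaped like x (via UnchangedInferMeta over [x]). Since the forward op is a sum-reduction, the gradient simply broadcasts out_grad back to x's shape; a hedged NumPy sketch (names hypothetical):

import numpy as np

def reduce_as_grad_reference(x: np.ndarray, out_grad: np.ndarray) -> np.ndarray:
    # Gradient of a sum-reduction: replicate out_grad across the reduced axes.
    return np.broadcast_to(out_grad, x.shape)

x = np.ones((2, 4))
out_grad = np.array([1.0, 2.0, 3.0, 4.0])  # shaped like target
print(reduce_as_grad_reference(x, out_grad))
# [[1. 2. 3. 4.]
#  [1. 2. 3. 4.]]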
20 changes: 10 additions & 10 deletions paddle/phi/api/yaml/ops.yaml
@@ -2287,6 +2287,16 @@
   inplace : (x -> out)
   backward : reciprocal_grad
 
+- op : reduce_as
+  args : (Tensor x, Tensor target)
+  output : Tensor(out)
+  infer_meta :
+    func : ReduceAsInferMeta
+  kernel :
+    func : reduce_as
+    data_type : x
+  backward : reduce_as_grad
+
 - op : reindex_graph
   args : (Tensor x, Tensor neighbors, Tensor count, Tensor hashtable_value, Tensor hashtable_index)
   output : Tensor(reindex_src), Tensor(reindex_dst), Tensor(out_nodes)
@@ -2769,16 +2779,6 @@
   func : stanh
   backward : stanh_grad
 
-- op : sum_as
-  args : (Tensor x, Tensor target)
-  output : Tensor(out)
-  infer_meta :
-    func : SumAsInferMeta
-  kernel :
-    func : sum_as
-    data_type : x
-  backward : sum_as_grad
-
 - op : svd
   args : (Tensor x, bool full_matrices = false)
   output : Tensor(u), Tensor(s), Tensor(vh)
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/binary.cc
@@ -3047,9 +3047,9 @@ void SequenceMaskInferMeta(const MetaTensor& x,
   y->set_dtype(out_dtype);
 }
 
-void SumAsInferMeta(const MetaTensor& x,
-                    const MetaTensor& target,
-                    MetaTensor* out) {
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out) {
   DataType out_dtype;
   if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32) {
     out_dtype = DataType::INT64;
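The visible part of ReduceAsInferMeta promotes bool and int32 inputs to an int64 output; the rest of the function is cut off above, so pass-through for other dtypes is an assumption. A sketch of that rule:

def infer_out_dtype(x_dtype: str) -> str:
    # bool/int32 accumulate into int64, as in the fragment above;
    # other dtypes are assumed to pass through unchanged.
    return "int64" if x_dtype in ("bool", "int32") else x_dtype

assert infer_out_dtype("int32") == "int64"
assert infer_out_dtype("float32") == "float32"  # assumed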
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/binary.h
@@ -524,9 +524,9 @@ void ShuffleBatchInferMeta(const MetaTensor& x,
 
 );
 
-void SumAsInferMeta(const MetaTensor& x,
-                    const MetaTensor& target,
-                    MetaTensor* out);
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out);
 
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc (file name inferred from the CPU kernel registration below)
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_kernel.h"
+#include "paddle/phi/kernels/reduce_as_kernel.h"
 
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -21,11 +21,11 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsGradKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& target,
-                     const DenseTensor& out_grad,
-                     DenseTensor* x_grad) {
+void ReduceAsGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& target,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   bool reduce_all = recompute_reduce_all(x, reduce_dim);
   ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
@@ -40,10 +40,10 @@ void SumAsGradKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as_grad,
+PD_REGISTER_KERNEL(reduce_as_grad,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SumAsGradKernel,
+                   phi::ReduceAsGradKernel,
                    bool,
                    float,
                    double,
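Both the forward and backward kernels derive the reduction axes from the pair of shapes via phi::funcs::GetReduceDims(x, target). Its exact contract is not shown in this diff; a plausible sketch, assuming target's shape must match the trailing dimensions of x:

def get_reduce_dims(x_shape: list, target_shape: list) -> list:
    # Align target with the trailing axes of x; reduce the leading rest.
    offset = len(x_shape) - len(target_shape)
    assert x_shape[offset:] == target_shape, "target must match x's trailing dims"
    return list(range(offset))

print(get_reduce_dims([2, 4], [4]))     # [0]
print(get_reduce_dims([3, 2, 4], [4]))  # [0, 1]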
paddle/phi/kernels/cpu/reduce_as_kernel.cc (file name inferred from the CPU kernel registration below)
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_kernel.h"
+#include "paddle/phi/kernels/reduce_as_kernel.h"
 
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -21,10 +21,10 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 const DenseTensor& target,
-                 DenseTensor* out) {
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   bool reduce_all = recompute_reduce_all(x, reduce_dim);
   phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
@@ -33,10 +33,10 @@ void SumAsKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as,
+PD_REGISTER_KERNEL(reduce_as,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SumAsKernel,
+                   phi::ReduceAsKernel,
                    bool,
                    float,
                    double,
paddle/phi/kernels/gpu/reduce_as_grad_kernel.cu (file name inferred from the GPU kernel registration below)
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_grad_kernel.h"
+#include "paddle/phi/kernels/reduce_as_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -22,11 +22,11 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsGradKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& target,
-                     const DenseTensor& out_grad,
-                     DenseTensor* x_grad) {
+void ReduceAsGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& target,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   bool reduce_all = recompute_reduce_all(x, reduce_dim);
   auto update_dims = common::vectorize(x.dims());
@@ -50,10 +50,10 @@ void SumAsGradKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as_grad,
+PD_REGISTER_KERNEL(reduce_as_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SumAsGradKernel,
+                   phi::ReduceAsGradKernel,
                    bool,
                    float,
                    double,
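The GPU backward path is truncated right after update_dims = common::vectorize(x.dims()); presumably it sets the reduced axes of update_dims to 1, reshapes out_grad to that keep-dim shape, and broadcasts to x's shape. A hedged NumPy sketch of that assumed flow:

import numpy as np

def gpu_reduce_as_grad_sketch(x, out_grad, reduce_dims):
    # Assumed flow: keep-dim reshape of out_grad, then broadcast to x's shape.
    update_dims = list(x.shape)
    for d in reduce_dims:
        update_dims[d] = 1
    return np.broadcast_to(out_grad.reshape(update_dims), x.shape)

x = np.zeros((2, 4))
out_grad = np.array([1.0, 2.0, 3.0, 4.0])
print(gpu_reduce_as_grad_sketch(x, out_grad, [0]))  # two rows of [1. 2. 3. 4.]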
paddle/phi/kernels/gpu/reduce_as_kernel.cu (file name inferred from the GPU kernel registration below)
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_kernel.h"
+#include "paddle/phi/kernels/reduce_as_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -21,21 +21,21 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 const DenseTensor& target,
-                 DenseTensor* out) {
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   dev_ctx.template Alloc<T>(out);
   phi::SumKernel<T, Context>(dev_ctx, x, reduce_dim, out->type(), false, out);
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as,
+PD_REGISTER_KERNEL(reduce_as,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SumAsKernel,
+                   phi::ReduceAsKernel,
                    bool,
                    float,
                    double,
paddle/phi/kernels/reduce_as_kernel.h (file name inferred from the #include directives above)
@@ -22,9 +22,9 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 const DenseTensor& target,
-                 DenseTensor* out);
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out);
 
 }  // namespace phi
paddle/phi/kernels/reduce_as_grad_kernel.h (file name inferred from the #include directives above)
@@ -22,10 +22,10 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsGradKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& target,
-                     const DenseTensor& out_grad,
-                     DenseTensor* x_grad);
+void ReduceAsGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& target,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad);
 
 }  // namespace phi
4 changes: 2 additions & 2 deletions python/paddle/__init__.py
@@ -474,6 +474,7 @@
     prod,
     rad2deg,
     reciprocal,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -494,7 +495,6 @@
     stanh,
     subtract,
     sum,
-    sum_as,
     take,
     tan,
     tan_,
@@ -847,7 +847,7 @@
     'ones',
     'not_equal',
     'sum',
-    'sum_as',
+    'reduce_as',
     'nansum',
     'nanmean',
     'count_nonzero',
4 changes: 2 additions & 2 deletions python/paddle/tensor/__init__.py
@@ -357,6 +357,7 @@
     rad2deg,
     reciprocal,
     reciprocal_,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -383,7 +384,6 @@
     subtract,
     subtract_,
     sum,
-    sum_as,
     take,
     tan,
     tan_,
@@ -526,7 +526,7 @@
     'square',
     'stanh',
     'sum',
-    'sum_as',
+    'reduce_as',
     'multigammaln',
     'multigammaln_',
     'nan_to_num',
14 changes: 7 additions & 7 deletions python/paddle/tensor/math.py
@@ -1576,7 +1576,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     return out
 
 
-def sum_as(x, target, name=None):
+def reduce_as(x, target, name=None):
     """
     Computes the sum of the tensor's elements so that the shape of the result matches the shape of ``target``.
@@ -1601,14 +1601,14 @@ def sum_as(x, target, name=None):
         >>> target
         Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
         [1, 2, 3, 4])
-        >>> res = paddle.sum_as(x, target)
+        >>> res = paddle.reduce_as(x, target)
         >>> res
         Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
         [6 , 8 , 10, 12])
     """
 
     if in_dynamic_or_pir_mode():
-        return _C_ops.sum_as(x, target)
+        return _C_ops.reduce_as(x, target)
     else:
         check_variable_and_dtype(
             x,
@@ -1623,7 +1623,7 @@ def sum_as(x, target, name=None):
                 'int32',
                 'int64',
             ],
-            'sum_as',
+            'reduce_as',
        )
        check_variable_and_dtype(
            target,
@@ -1638,13 +1638,13 @@
                 'int32',
                 'int64',
             ],
-            'sum_as',
+            'reduce_as',
        )
 
-        helper = LayerHelper('sum_as', **locals())
+        helper = LayerHelper('reduce_as', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
-            type='sum_as',
+            type='reduce_as',
            inputs={'x': x, 'target': target},
            outputs={'out': out},
        )
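After this rename, call sites migrate by swapping the function name; a minimal usage sketch mirroring the docstring example above (output shown as expected, not re-verified here):

import paddle

x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
target = paddle.to_tensor([1, 2, 3, 4])
out = paddle.reduce_as(x, target)  # formerly paddle.sum_as(x, target)
print(out)  # Tensor(shape=[4], ...) -> [6, 8, 10, 12]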