[Zero-Dim] fix reduce all and not keep dims case #53000

Closed
7 changes: 3 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -98,10 +98,9 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};

-DECLARE_INFER_SHAPE_FUNCTOR(
-    reduce_mean,
-    ReduceMeanInferShapeFunctor,
-    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
+                            ReduceMeanInferShapeFunctor,
+                            PD_INFER_META(phi::MeanRawInferMeta));

REGISTER_OPERATOR(reduce_mean,
ops::ReduceBaseOp,
@@ -318,7 +318,7 @@ void sum_grad(const Tensor& x,
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
-      for (int64_t i = 1; i < x_dim_size; i++) {
+      for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
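Note on the hunk above: when reduce_all is true and keepdim is false, sum_grad rebuilds the full list of reduced axes so the incoming 0-D gradient can be unsqueezed back to the input's rank. Starting the loop at 1 silently dropped axis 0, leaving the rebuilt gradient one dimension short. A minimal numpy sketch of the intended shape arithmetic (illustration only, not Paddle's API):

import numpy as np

x = np.ones((3, 4))
out_grad = np.asarray(1.0)                 # 0-D: sum over all axes, keepdim=False

def rebuild_axes(rank, start):
    # the C++ loop: for (int64_t i = start; i < x_dim_size; i++) axis_.push_back(i);
    return list(range(start, rank))

assert rebuild_axes(x.ndim, 1) == [1]      # old loop: axis 0 is lost
assert rebuild_axes(x.ndim, 0) == [0, 1]   # fixed loop: every reduced axis

# With the full axis list, the 0-D gradient unsqueezes back to rank 2 and
# broadcasts to x's shape, matching the forward reduction.
g = out_grad.reshape([1] * len(rebuild_axes(x.ndim, 0)))   # shape (1, 1)
assert np.broadcast_to(g, x.shape).shape == (3, 4)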
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_ops.yaml
@@ -820,7 +820,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
-    func : ReduceIntArrayAxisInferMeta
+    func : MeanInferMeta
kernel :
func : mean
backward : mean_grad
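With the yaml entry now pointing at MeanInferMeta, mean's inferred output shape for a full reduction without keepdim is a true 0-D shape rather than a 1-element 1-D one. The intended semantics match numpy's (numpy is used below only as a reference; this is not Paddle code):

import numpy as np

x = np.ones((2, 3))
assert np.mean(x).shape == ()                     # reduce all, keepdim=False -> 0-D
assert np.mean(x, keepdims=True).shape == (1, 1)  # reduce all, keepdim=True
assert np.mean(x, axis=1).shape == (2,)           # partial reduce drops one axis
assert np.mean(np.ones(())).shape == ()           # 0-D input stays 0-D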
162 changes: 144 additions & 18 deletions paddle/phi/infermeta/unary.cc
@@ -3125,22 +3125,6 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}

-void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
-                                     const IntArray& axis,
-                                     bool keep_dim,
-                                     bool reduce_all,
-                                     MetaTensor* out,
-                                     MetaConfig config) {
-  if (config.is_runtime || !axis.FromTensor()) {
-    ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
-  } else {
-    DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
-    out->set_dims(out_dim);
-    out->set_dtype(x.dtype());
-    out->set_layout(x.layout());
-  }
-}

void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
@@ -3153,6 +3137,23 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}

+void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
+                                     const IntArray& axis,
+                                     bool keep_dim,
+                                     bool reduce_all,
+                                     MetaTensor* out,
+                                     MetaConfig config) {
+  DDim out_dim;
+  if (config.is_runtime || !axis.FromTensor()) {
+    out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+  } else {
+    out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+  }
+  out->set_dims(out_dim);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}

void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
if (dim[0] > 0 || dim[0] < -1) {
@@ -3951,6 +3952,100 @@ void StridedSliceInferMeta(const MetaTensor& x,
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}

+DDim ReduceSumMeanInferDim(const MetaTensor& x,
+                           const std::vector<int64_t>& axis,
+                           bool keep_dim,
+                           bool reduce_all) {
+  auto x_rank = x.dims().size();
+
+  std::vector<int64_t> formated_axis = axis;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    if (x_rank == 0) {
+      PADDLE_ENFORCE_EQ(
+          axis[i] == 0 || axis[i] == -1,
+          true,
+          phi::errors::InvalidArgument(
+              "When input 0D Tensor, the axis can only be -1, 0, None or []"));
+    } else {
+      PADDLE_ENFORCE_LT(axis[i],
+                        x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+      PADDLE_ENFORCE_GE(axis[i],
+                        -x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+    }
+
+    if (axis[i] < 0) {
+      formated_axis[i] = axis[i] + x_rank;
+    }
+  }
+
+  bool full_dim = true;
+  std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
+  for (int64_t i = 0; i < x_rank; ++i) {
+    if (dims_set.find(i) == dims_set.end()) {
+      full_dim = false;
+      break;
+    }
+  }
+  reduce_all = reduce_all || full_dim;
+
+  std::vector<int64_t> out_dim_vector;
+  for (int64_t i = 0; i < x_rank; ++i) {
+    if (reduce_all || dims_set.find(i) != dims_set.end()) {
+      if (keep_dim) {
+        out_dim_vector.push_back(1);
+      } else {
+        continue;
+      }
+    } else {
+      out_dim_vector.push_back(x.dims().at(i));
+    }
+  }
+
+  DDim out_dim = phi::make_ddim(out_dim_vector);
+  return out_dim;
+}
+
+DDim ReduceSumMeanInferDimForIntArrayAxis(const MetaTensor& x,
+                                          const IntArray& axis,
+                                          bool keep_dim,
+                                          bool reduce_all) {
+  std::vector<int64_t> vec_axis = axis.GetData();
+  std::vector<int64_t> vec_dim;
+  if (reduce_all) {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec_dim = {};
+    }
+  } else {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), -1);
+    } else {
+      auto x_rank = static_cast<size_t>(x.dims().size());
+      if (vec_axis.size() > x_rank) {
+        vec_dim = {-1};
+      } else {
+        vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
+      }
+    }
+  }
+  return phi::make_ddim(vec_dim);
+}

/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of
ops.yaml
@@ -3977,9 +4072,10 @@ void SumRawInferMeta(const MetaTensor& x,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
-    out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+    out_dim = ReduceSumMeanInferDim(x, axis.GetData(), keep_dim, reduce_all);
  } else {
-    out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+    out_dim =
+        ReduceSumMeanInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}

DataType out_dtype;
@@ -3998,6 +4094,36 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

+void MeanInferMeta(const MetaTensor& x,
+                   const IntArray& axis,
+                   bool keep_dim,
+                   MetaTensor* out,
+                   MetaConfig config) {
+  bool reduce_all = false;
+  if (axis.size() == 0) {
+    reduce_all = true;
+  }
+  MeanRawInferMeta(x, axis, keep_dim, reduce_all, out, config);
+}
+
+void MeanRawInferMeta(const MetaTensor& x,
+                      const IntArray& axis,
+                      bool keep_dim,
+                      bool reduce_all,
+                      MetaTensor* out,
+                      MetaConfig config) {
+  DDim out_dim;
+  if (config.is_runtime || !axis.FromTensor()) {
+    out_dim = ReduceSumMeanInferDim(x, axis.GetData(), keep_dim, reduce_all);
+  } else {
+    out_dim =
+        ReduceSumMeanInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+  }
+  out->set_dims(out_dim);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
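The two helpers above are the core of the fix. The following runnable Python mirror (a sketch for reading the logic, not Paddle code) reproduces both paths and checks the runtime rule against numpy. In the compile-time path, -1 marks a dimension whose extent is unknown until the axis tensor is materialized; the reduce_all + !keep_dim case now yields an empty (0-D) shape where the old helper presumably produced a 1-element one.

import numpy as np

def reduce_infer_dim(x_shape, axis, keep_dim, reduce_all):
    # mirrors ReduceSumMeanInferDim (runtime path)
    rank = len(x_shape)
    axis = [a + rank if a < 0 else a for a in axis]
    reduce_all = reduce_all or set(axis) == set(range(rank))
    out = []
    for i in range(rank):
        if reduce_all or i in axis:
            if keep_dim:
                out.append(1)
        else:
            out.append(x_shape[i])
    return tuple(out)

def reduce_infer_dim_for_tensor_axis(x_rank, num_axes, keep_dim, reduce_all):
    # mirrors ReduceSumMeanInferDimForIntArrayAxis (compile-time path)
    if reduce_all:
        return [1] * x_rank if keep_dim else []    # [] is the 0-D shape
    if keep_dim:
        return [-1] * x_rank
    return [-1] if num_axes > x_rank else [-1] * (x_rank - num_axes)

x = np.ones((2, 3, 4))
for axes, keep in [((0, 1, 2), False), ((1,), False), ((1,), True)]:
    assert reduce_infer_dim(x.shape, list(axes), keep, False) == \
        np.sum(x, axis=axes, keepdims=keep).shape
assert reduce_infer_dim((), [], False, True) == ()             # 0-D in, 0-D out
assert reduce_infer_dim_for_tensor_axis(3, 3, False, True) == []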
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -572,6 +572,19 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

+void MeanInferMeta(const MetaTensor& x,
+                   const IntArray& axis,
+                   bool keep_dim,
+                   MetaTensor* out,
+                   MetaConfig config = MetaConfig());
+
+void MeanRawInferMeta(const MetaTensor& x,
+                      const IntArray& axis,
+                      bool keep_dim,
+                      bool reduce_all,
+                      MetaTensor* out,
+                      MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -89,6 +89,7 @@ PD_REGISTER_KERNEL(add_n,
double,
int,
phi::dtype::bfloat16,
+                   phi::dtype::float16,
int64_t) {}

@@ -99,4 +100,5 @@ PD_REGISTER_KERNEL(add_n_array,
double,
int,
phi::dtype::bfloat16,
+                   phi::dtype::float16,
int64_t) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/funcs/selected_rows_functor.cc
@@ -395,6 +395,7 @@ template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
+template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::float16>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

#ifdef PADDLE_WITH_XPU
18 changes: 9 additions & 9 deletions paddle/phi/kernels/funcs/unsqueeze.h
@@ -105,32 +105,32 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,

inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
const DDim& in_dims) {
-  int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
-  int cur_output_size = in_dims.size();
-  std::vector<int64_t> output_shape(output_size, 0);
+  int output_rank = in_dims.size() + static_cast<int>(unsqz_dims.size());
+  int cur_output_rank = in_dims.size();
+  std::vector<int64_t> output_shape(output_rank, 0);

// Validity Check: rank range.
PADDLE_ENFORCE_LE(
-      output_size,
+      output_rank,
6,
phi::errors::InvalidArgument("The output "
"tensor's rank should be less than 6."));

for (int axis : unsqz_dims) {
-    int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
+    int cur = axis < 0 ? axis + cur_output_rank + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE_GE(
cur,
0,
phi::errors::InvalidArgument("The insert dimension value should "
"not be less than 0"));
PADDLE_ENFORCE_LE(cur,
-                      cur_output_size,
+                      cur_output_rank,
phi::errors::InvalidArgument(
"The insert dimension value shoule not be larger "
"than the dimension size of input tensor"));
// Move old axis, and insert new axis
-    for (int i = cur_output_size; i >= cur; --i) {
+    for (int i = cur_output_rank; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
Expand All @@ -139,11 +139,11 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
}
output_shape[cur] = 1;
// Add the output size.
-    cur_output_size++;
+    cur_output_rank++;
}

// Make output shape
-  for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
+  for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
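The rename from *_size to *_rank makes it explicit that these counters track tensor rank, not element counts; the insertion algorithm itself is unchanged. A Python mirror (sketch only), cross-checked against numpy.expand_dims:

import numpy as np

def get_unsqueeze_shape(unsqz_dims, in_dims):
    out_rank = len(in_dims) + len(unsqz_dims)
    cur_rank = len(in_dims)
    shape = [0] * (out_rank + 1)   # 0 marks "copy the next input dim here";
                                   # one slot of slack mirrors the in-place shift
    for axis in unsqz_dims:
        cur = axis + cur_rank + 1 if axis < 0 else axis
        assert 0 <= cur <= cur_rank, "unsqueeze axis out of range"
        for i in range(cur_rank, cur - 1, -1):   # shift existing 1-marks right
            if shape[i] == 1:
                shape[i + 1] = 1
                shape[i] = 0
        shape[cur] = 1
        cur_rank += 1
    it = iter(in_dims)
    return tuple(next(it) if s == 0 else 1 for s in shape[:out_rank])

x = np.ones((3, 4))
for axes in ([0], [1], [-1], [0, 2]):
    expected = x
    for a in axes:
        expected = np.expand_dims(expected, a)
    assert get_unsqueeze_shape(axes, x.shape) == expected.shape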
6 changes: 4 additions & 2 deletions paddle/phi/kernels/onednn/reduce_kernel_impl.h
@@ -102,8 +102,10 @@ void ReduceKernel(const Context& dev_ctx,
reduction_p->execute(astream, reduction_args);
astream.wait();

-    out->set_mem_desc(
-        dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
+    const auto reshape_dims = out->dims().size() != 0
+                                  ? vectorize<int64_t>(out->dims())
+                                  : std::vector<int64_t>{1};
+    out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
}

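The guard above handles reduce outputs that are now 0-D: the framework-side shape stays (), but the oneDNN memory descriptor is given a single-element 1-D shape instead (assumption: the descriptor reshape does not accept an empty dims list). A sketch of the dims fallback:

def onednn_reshape_dims(out_dims):
    # mirrors the reshape_dims ternary in the C++ hunk above
    return list(out_dims) if len(out_dims) != 0 else [1]

assert onednn_reshape_dims(()) == [1]         # 0-D result -> 1-element descriptor
assert onednn_reshape_dims((2, 3)) == [2, 3]  # higher ranks pass through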
1 change: 0 additions & 1 deletion paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -30,7 +30,6 @@ void DivScalarCooKernel(const Context& dev_ctx,
float scalar,
SparseCooTensor* out) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
-
auto eigen_out =
phi::EigenVector<T>::Flatten(*(out->mutable_non_zero_elements()));
auto eigen_x = phi::EigenVector<T>::Flatten(x.non_zero_elements());
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -13,7 +13,8 @@
# limitations under the License

from collections import OrderedDict
-from functools import reduce
+
+import numpy as np

import paddle
from paddle.utils.flops import flops
@@ -807,7 +808,7 @@ def comm_count(self):
factor = 8
else:
raise ValueError(f"Unsupported comm dtype {dtype}")
-        comm_count = reduce(lambda x, y: x * y, shape) * factor
+        comm_count = int(np.prod(shape)) * factor
self._comm_count = comm_count

return self._comm_count
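Switching to np.prod is what makes this robust for 0-D tensors: their shape is an empty list, and reduce() without an initial value raises on an empty sequence, while the empty product is defined as 1, the element count of a 0-D tensor.

from functools import reduce
import numpy as np

shape = []                                   # shape of a 0-D tensor
assert int(np.prod(shape)) == 1              # one element -> nonzero comm size
try:
    reduce(lambda a, b: a * b, shape)        # the old code path
except TypeError as e:
    print("reduce fails on a 0-D shape:", e)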
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/engine.py
@@ -493,7 +493,7 @@ def _prepare_logger(
loss_indices = fetch_indices[group_idx]
assert len(loss_indices) <= 1
for idx in loss_indices:
-                logs["loss"] = outs[idx][0]
+                logs["loss"] = outs[idx]
group_idx += 1
# logging metrics
dist_context = self._dist_contexts[mode]
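Since a fully reduced loss is now fetched as a 0-D value, outs[idx] is already the scalar, and the old outs[idx][0] indexing would fail on it. A sketch with numpy standing in for the fetch result:

import numpy as np

loss = np.array(0.25)            # 0-D fetch result
assert loss.ndim == 0 and float(loss) == 0.25
try:
    loss[0]                      # the removed [0] indexing
except IndexError as e:
    print("0-D values cannot be indexed:", e)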
@@ -393,7 +393,7 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):

for var_name in act_grad_names:
var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
-        # consider that the variable's shape is None
+        # consider that the variable's shape is [], which is 0D
# TODO utilize the batch_dim attr instead of "0" in future
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

@@ -159,7 +159,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break
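This mirrors the guard already present in get_data_parallel_group above: a 0-D variable has an empty dims mapping, so there is no batch axis to shard and the gradient allreduce is skipped. A sketch of the fallback:

def batch_size_axis(var_dim_mapping):
    # mirrors the guarded indexing added in this hunk
    return var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

assert batch_size_axis([]) == -1       # 0-D var: old code raised IndexError here
assert batch_size_axis([0, -1]) == 0   # batch axis sharded along mesh dim 0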