[Zero-Dim] Support paddle.sum/mean/loss api output 0D, test=allcase #52739

Merged
merged 1 commit on Apr 30, 2023
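As a quick illustration of the user-visible change (a minimal sketch, not taken from this PR's tests; the printed shapes assume the post-merge behavior):

```python
import paddle

x = paddle.rand([3, 4])

# Full reductions now return a 0-D tensor (shape []) instead of shape [1].
s = paddle.sum(x)
m = paddle.mean(x)
print(s.shape, m.shape)              # expected after this PR: [] []

# Reducing along a given axis is unchanged.
print(paddle.sum(x, axis=1).shape)   # [3]

# A 0-D loss can be consumed directly as a scalar.
print(float(m))
```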
@@ -317,7 +317,7 @@ void sum_grad(const Tensor& x,
    if (!keepdim) {
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
-       for (int64_t i = 1; i < x_dim_size; i++) {
+       for (int64_t i = 0; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
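The hunk above is the matching backward fix: with a 0-D output, out_grad has to be unsqueezed along every axis of x before broadcasting, whereas the old [1]-shaped output let axis 0 be skipped. A minimal dygraph check of the intended end-to-end behavior (a sketch, not the PR's unit test):

```python
import paddle

x = paddle.rand([2, 3])
x.stop_gradient = False

y = paddle.sum(x)     # 0-D output after this PR
y.backward()

print(y.shape)        # expected: []
print(x.grad.shape)   # expected: [2, 3], grad is broadcast back to x's shape
```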
7 changes: 2 additions & 5 deletions paddle/phi/infermeta/unary.cc
@@ -4002,9 +4002,6 @@ DDim OriginReduceInferDim(const MetaTensor& x,
      out_dim_vector.push_back(x.dims().at(i));
    }
  }
- if (x_rank > 0 && out_dim_vector.size() == 0) {
-   out_dim_vector.push_back(1);
- }

  DDim out_dim = phi::make_ddim(out_dim_vector);
  return out_dim;
@@ -4021,14 +4018,14 @@ DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
    if (keep_dim) {
      vec_dim = std::vector<int64_t>(x.dims().size(), 1);
    } else {
-     vec_dim = {1};
+     vec_dim = {};
    }
  } else {
    if (keep_dim) {
      vec_dim = std::vector<int64_t>(x.dims().size(), -1);
    } else {
      auto x_rank = static_cast<size_t>(x.dims().size());
-     if (vec_axis.size() >= x_rank) {
+     if (vec_axis.size() > x_rank) {
        vec_dim = {-1};
      } else {
        vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
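The two hunks above change reduce shape inference so a full reduction no longer pads the output shape to [1]. A standalone Python sketch of the rule as implied by the diff (illustrative helper, not the phi C++ API):

```python
def reduce_out_shape(x_shape, axis, keep_dim=False):
    """Infer the output shape of a reduce op; a full reduction now yields [] (0-D)."""
    rank = len(x_shape)
    reduce_all = len(axis) == 0 or len(axis) == rank
    if reduce_all:
        return [1] * rank if keep_dim else []   # previously [] fell back to [1]
    axis = [a % rank for a in axis]
    if keep_dim:
        return [1 if i in axis else d for i, d in enumerate(x_shape)]
    return [d for i, d in enumerate(x_shape) if i not in axis]


assert reduce_out_shape([3, 4], []) == []               # paddle.sum(x) -> 0-D
assert reduce_out_shape([3, 4], [0, 1]) == []
assert reduce_out_shape([3, 4], [1]) == [3]
assert reduce_out_shape([3, 4], [1], keep_dim=True) == [3, 1]
```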
20 changes: 13 additions & 7 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -1688,7 +1688,7 @@ def complete_update_annotation(self, serial_main_program):
                        world_ranks
                    )
                    out_dist_attr.dims_mapping = [
-                       -1 for _ in range(len(out_var.shape))
+                       -1 for _ in out_var.shape
                    ]
                    self._dist_context.set_tensor_dist_attr_for_program(
                        out_var, out_dist_attr
@@ -1732,7 +1732,9 @@ def complete_update_annotation(self, serial_main_program):
                        len(out_var.shape) == 1
                        and out_var.shape[0] == 1
                    )
-                   out_dist_attr.dims_mapping = [-1]
+                   out_dist_attr.dims_mapping = [
+                       -1 for _ in out_var.shape
+                   ]
                    self._dist_context.set_tensor_dist_attr_for_program(
                        out_var, out_dist_attr
                    )
@@ -1802,16 +1804,20 @@ def complete_update_annotation(self, serial_main_program):
                    param.name, ref_dims_mapping
                )
                learning_var = vars[op.input("LearningRate")[0]]
-               op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
+               op_dist_attr.set_input_dims_mapping(
+                   learning_var.name, [-1 for _ in learning_var.shape]
+               )
                op_dist_attr.set_output_dims_mapping(
-                   learning_var.name, [-1]
+                   learning_var.name, [-1 for _ in learning_var.shape]
                )

                if not learning_rate_completed:
                    learning_rate_completed = True
                    var_dist_attr = TensorDistAttr()
                    var_dist_attr.process_mesh = ProcessMesh(world_ranks)
-                   var_dist_attr.dims_mapping = [-1]
+                   var_dist_attr.dims_mapping = [
+                       -1 for _ in learning_var.shape
+                   ]
                    self._dist_context.set_tensor_dist_attr_for_program(
                        learning_var, var_dist_attr
                    )
@@ -1841,10 +1847,10 @@ def complete_update_annotation(self, serial_main_program):
                ):
                    input_var_attr.dims_mapping = [-1]
                    op_dist_attr.set_input_dims_mapping(
-                       input_var.name, [-1]
+                       input_var.name, [-1 for _ in input_var.shape]
                    )
                    op_dist_attr.set_output_dims_mapping(
-                       input_var.name, [-1]
+                       input_var.name, [-1 for _ in input_var.shape]
                    )
                else:
                    input_var_attr.dims_mapping = ref_dims_mapping
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/engine.py
@@ -511,7 +511,7 @@ def _prepare_logger(
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
-               logs["loss"] = outs[idx][0]
+               logs["loss"] = outs[idx]
Contributor:
What's the reason to remove the [0]?

Contributor Author @zhwesky2010 (Apr 28, 2023):
outs holds the loss values fetched by the executor. The fetched loss used to be 1-D, so index 0 was taken to slice it; the new loss is 0-D with shape [], so it can no longer be indexed and is used directly as a float.
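For readers skimming the thread, a minimal sketch of the before/after fetch handling (an illustrative static-graph program, not the PR's code):

```python
import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    data = paddle.static.data(name="x", shape=[4, 3], dtype="float32")
    loss = paddle.mean(data)   # 0-D after this PR; previously shape [1]

exe = paddle.static.Executor()
outs = exe.run(main,
               feed={"x": np.random.rand(4, 3).astype("float32")},
               fetch_list=[loss])

# Before: the fetched loss had shape (1,), so callers wrote outs[0][0].
# After:  it is 0-D, so outs[0] is already scalar-like and is used directly.
print(outs[0].shape)   # expected after this PR: ()
print(float(outs[0]))
```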

group_idx += 1
# logging metrics
dist_context = self._dist_contexts[mode]
@@ -393,7 +393,7 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):

    for var_name in act_grad_names:
        var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
-       # consider that the variable's shape is None
+       # consider that the variable's shape is [], which is 0D
        # TODO utilize the batch_dim attr instead of "0" in future
        batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

@@ -159,7 +159,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
        ):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                need_gradient_allreduce = True
                break
@@ -101,7 +101,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
        ):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                need_gradient_allreduce = True
                break
@@ -252,7 +252,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("Ids")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
18 changes: 9 additions & 9 deletions python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -651,7 +651,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -1028,7 +1028,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -1365,7 +1365,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -1552,7 +1552,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -1929,7 +1929,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -2264,7 +2264,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -2449,7 +2449,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -2832,7 +2832,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -3178,7 +3178,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.shape
-       batch_size_axis = var_dim_mapping[0]
+       batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
@@ -120,7 +120,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
@@ -377,7 +379,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
@@ -637,7 +641,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
@@ -100,7 +100,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
        ):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                need_gradient_allreduce = True
                break
@@ -94,7 +94,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
@@ -183,7 +183,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
            var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

            mesh_shape = process_mesh.shape
-           batch_size_axis = var_dim_mapping[0]
+           batch_size_axis = (
+               var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+           )
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
@@ -1727,7 +1727,9 @@ def _complete_sub_update_program(self, sub_program_dist_context):
                        len(out_var.shape) == 1
                        and out_var.shape[0] == 1
                    )
-                   out_dist_attr.dims_mapping = [-1]
+                   out_dist_attr.dims_mapping = [
+                       -1 for _ in out_var.shape
+                   ]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        out_var, out_dist_attr
                    )
@@ -1798,17 +1800,19 @@ def _complete_sub_update_program(self, sub_program_dist_context):
                )
                learning_var = vars[op.input("LearningRate")[0]]
                op_dist_attr.set_input_dims_mapping(
-                   learning_var.name, [-1]
+                   learning_var.name, [-1 for i in learning_var.shape]
                )
                op_dist_attr.set_output_dims_mapping(
-                   learning_var.name, [-1]
+                   learning_var.name, [-1 for i in learning_var.shape]
                )

                if not learning_rate_completed:
                    learning_rate_completed = True
                    var_dist_attr = TensorDistAttr()
                    var_dist_attr.process_mesh = world_ranks
-                   var_dist_attr.dims_mapping = [-1]
+                   var_dist_attr.dims_mapping = [
+                       -1 for i in learning_var.shape
+                   ]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        learning_var, var_dist_attr
                    )
13 changes: 9 additions & 4 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1466,7 +1466,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
            ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
                op_desc.type(), idx, mapping
            )
-           batch_dim_mappings.append(dims_mapping[0])
+           if len(dims_mapping) >= 1:
+               batch_dim_mappings.append(dims_mapping[0])
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
@@ -1480,7 +1481,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
                ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
                    op_desc.type(), idx, mapping
                )
-               batch_dim_mappings.append(dims_mapping[0])
+               if len(dims_mapping) >= 1:
+                   batch_dim_mappings.append(dims_mapping[0])
        else:
            assert (
                dims_mapping[0] == -1
@@ -1505,7 +1507,7 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
-       if compatible_dim_mapping != dims_mapping[0]:
+       if len(dims_mapping) >= 1 and compatible_dim_mapping != dims_mapping[0]:
            dims_mapping[0] = compatible_dim_mapping
            changed = True
    for arg_name in op_desc.output_arg_names():
@@ -1514,7 +1516,10 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
-           if compatible_dim_mapping != dims_mapping[0]:
+           if (
+               len(dims_mapping) >= 1
+               and compatible_dim_mapping != dims_mapping[0]
+           ):
                dims_mapping[0] = compatible_dim_mapping
                changed = True
        else:
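Most of the auto_parallel changes above repeat one defensive pattern: a 0-D tensor has an empty dims_mapping, so reading dims_mapping[0] for the batch axis must be guarded. A standalone sketch of that idea (the helper name is illustrative, not Paddle's API):

```python
def batch_size_axis_of(var_dim_mapping):
    """Mesh axis that shards the batch dim, or -1 if unsharded or the tensor is 0-D."""
    # A 0-D tensor (e.g. a scalar loss or learning rate) has an empty dims_mapping.
    return var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1


assert batch_size_axis_of([]) == -1        # 0-D: nothing to shard
assert batch_size_axis_of([0, -1]) == 0    # batch dim sharded on mesh axis 0
assert batch_size_axis_of([-1, -1]) == -1  # replicated
```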