[Zero-Dim] Support 0D Tensor input for topk/broadcast_to/expand/expand_as/broadcast_shape (#50536)
yunyaoXYY authored Feb 24, 2023
1 parent 4a0855a commit 5041158
Showing 11 changed files with 578 additions and 137 deletions.
39 changes: 25 additions & 14 deletions paddle/phi/infermeta/unary.cc
@@ -472,6 +472,7 @@ void CumInferMeta(const MetaTensor& x,
out->set_dims(x_dims);
out->set_dtype(x.dtype());
}

out->share_lod(x);
}

@@ -970,7 +971,7 @@ void ExpandInferMeta(const MetaTensor& x,
MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(
expand_shape.size(),
-     1,
+     0,
phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for "
"must be a positive integer.",
expand_shape.size()));
@@ -1005,7 +1006,7 @@

out->set_dims(make_ddim(out_shape));
out->set_dtype(x.dtype());
- if (out_shape[0] == x_dims[0]) {
+ if (out_rank > 0 && out_shape[0] == x_dims[0]) {
out->share_lod(x);
}
}
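The two ExpandInferMeta changes above relax the minimum 'shape' size from 1 to 0 and guard the out_shape[0] access, because a 0-d output has no first dimension to compare. A minimal standalone sketch of the resulting shape rule (illustrative only, not Paddle code; names are invented for the example):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Expand a shape by a target 'shape' attribute: -1 keeps the input extent.
// An empty target is now legal and maps a 0-d tensor to a 0-d tensor.
std::vector<int> ExpandedShape(const std::vector<int>& x_dims,
                               const std::vector<int>& expand_shape) {
  assert(expand_shape.size() >= x_dims.size());  // rank can only grow
  const size_t diff = expand_shape.size() - x_dims.size();
  std::vector<int> out(expand_shape.size());
  for (size_t i = 0; i < expand_shape.size(); ++i) {
    const int x_d = (i < diff) ? 1 : x_dims[i - diff];  // left-pad with 1s
    out[i] = (expand_shape[i] == -1) ? x_d : expand_shape[i];
  }
  return out;
}

int main() {
  const auto a = ExpandedShape({}, {});      // 0-d -> 0-d, newly allowed
  std::printf("rank %zu\n", a.size());       // prints: rank 0
  const auto b = ExpandedShape({}, {3, 2});  // 0-d broadcast up to [3, 2]
  std::printf("[%d, %d]\n", b[0], b[1]);
  return 0;
}
```

With an empty target the loop never runs and the output stays rank 0, which is also why the new `out_rank > 0` guard is needed before touching `out_shape[0]`.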
@@ -4097,14 +4098,23 @@ void TopKInferMeta(const MetaTensor& x,
MetaConfig config) {
auto input_dims = x.dims();
const int& dim_size = input_dims.size();
- PADDLE_ENFORCE_EQ(
-     (axis < dim_size) && (axis >= (-1 * dim_size)),
-     true,
-     phi::errors::InvalidArgument(
-         "the axis of topk must be [-%d, %d), but you set axis is %d",
-         dim_size,
-         dim_size,
-         axis));
+ if (dim_size != 0) {
+   PADDLE_ENFORCE_EQ(
+       (axis < dim_size) && (axis >= (-1 * dim_size)),
+       true,
+       phi::errors::InvalidArgument(
+           "the axis of topk must be [-%d, %d), but you set axis is %d",
+           dim_size,
+           dim_size,
+           axis));
+ } else {
+   PADDLE_ENFORCE_EQ(
+       (axis == dim_size) || (axis == -1),
+       true,
+       phi::errors::InvalidArgument("the axis of topk must be 0 or -1 when "
+                                    "x.dims() = 0, but you set axis is %d",
+                                    axis));
+ }

if (axis < 0) axis += dim_size;

@@ -4122,12 +4132,13 @@ void TopKInferMeta(const MetaTensor& x,

PADDLE_ENFORCE_GE(
input_dims.size(),
-     1,
-     phi::errors::InvalidArgument("input of topk must have >= 1d shape"));
+     0,
+     phi::errors::InvalidArgument("input of topk must have >= 0d shape"));

phi::DDim dims = input_dims;

- dims[axis] = k;
+ if (input_dims.size() > 0) {
+   dims[axis] = k;
+ }
out->set_dims(dims);
out->share_lod(x);
out->set_dtype(x.dtype());
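For a 0-d input the hunk above only admits `axis == 0` or `axis == -1`, and it skips `dims[axis] = k` because an empty `DDim` has no axis to resize. A standalone sketch of that shape rule (an illustration of the logic, not the Paddle implementation):

```cpp
#include <cassert>
#include <vector>

// Output shape of top-k along 'axis'; a rank-0 shape passes through unchanged.
std::vector<int64_t> TopKOutDims(std::vector<int64_t> dims, int axis, int k) {
  const int rank = static_cast<int>(dims.size());
  if (rank != 0) {
    assert(axis >= -rank && axis < rank);  // the usual [-rank, rank) window
  } else {
    assert(axis == 0 || axis == -1);       // a scalar has no real axis
  }
  if (axis < 0) axis += rank;  // leaves axis negative when rank == 0, but it
                               // is never used in that case
  if (rank > 0) {
    dims[axis] = k;            // only resize when an axis actually exists
  }
  return dims;                 // {} stays {} for a 0-d input
}

int main() {
  assert(TopKOutDims({5, 7}, -1, 3) == (std::vector<int64_t>{5, 3}));
  assert(TopKOutDims({}, -1, 1).empty());  // 0-d in, 0-d out
  return 0;
}
```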
5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/top_k_grad_kernel.cc
@@ -66,6 +66,11 @@ void TopkGradKernel(const Context& dev_ctx,
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+ if (in_dims.size() == 0) {
+   phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+   return;
+ }

if (axis + 1 == in_dims.size()) {
// allocate the memory for the input_grad

8 changes: 7 additions & 1 deletion paddle/phi/kernels/cpu/top_k_kernel.cc
@@ -140,7 +140,13 @@ void TopkKernel(const Context& dev_ctx,
const auto* input = &x;
// Get the top k elements of each row of input tensor
const auto& in_dims = input->dims();

+ // 0d input x
+ if (in_dims.size() == 0) {
+   phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+   dev_ctx.template Alloc<int64_t>(indices);
+   phi::funcs::set_constant(dev_ctx, indices, 0.0);
+   return;
+ }
// axis < 0, cacluate the real axis
if (axis < 0) {
axis += in_dims.size();
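The early returns added to the CPU kernel above (and mirrored in the CUDA kernels below) pin down the 0-d contract: the top-k of a scalar is the scalar itself with index 0, and the backward pass is then a plain copy of `out_grad`. The `set_constant(dev_ctx, indices, 0.0)` call writes that index into the freshly allocated int64 indices tensor. A toy model of the contract (illustrative stand-ins, not Paddle types):

```cpp
#include <cassert>

struct Scalar0d {  // stand-in for a 0-d tensor holding a single value
  float value;
};

// Forward: the "top 1 of a scalar" is the scalar itself, at index 0.
void TopK0d(const Scalar0d& x, Scalar0d* out, long long* index) {
  *out = x;
  *index = 0;
}

// Backward: the forward pass was an identity copy, so the gradient
// passes straight through.
void TopK0dGrad(const Scalar0d& out_grad, Scalar0d* x_grad) {
  *x_grad = out_grad;
}

int main() {
  Scalar0d x{3.5f}, out{}, g{};
  long long idx = -1;
  TopK0d(x, &out, &idx);
  assert(out.value == 3.5f && idx == 0);
  TopK0dGrad(Scalar0d{1.0f}, &g);
  assert(g.value == 1.0f);
  return 0;
}
```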
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/top_k_grad_kernel.cu
@@ -47,6 +47,11 @@ void TopkGradKernel(const Context& dev_ctx,
const T* out_grad_data = out_grad.data<T>();
const int64_t* indices_data = indices.data<int64_t>();

+ if (in_dims.size() == 0) {
+   phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+   return;
+ }

int pre, n, post;
phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);

8 changes: 8 additions & 0 deletions paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -61,6 +61,14 @@ void TopkKernel(const Context& dev_ctx,
const auto* input = &x;
// get the input dims
const auto& in_dims = input->dims();

+ // 0d input tensor
+ if (in_dims.size() == 0) {
+   phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+   dev_ctx.template Alloc<int64_t>(indices);
+   phi::funcs::set_constant(dev_ctx, indices, 0.0);
+   return;
+ }
// calculate the real axis
if (axis < 0) axis += in_dims.size();

112 changes: 55 additions & 57 deletions paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h
@@ -49,6 +49,12 @@ void ExpandAsGradKernel(const Context& context,
const std::vector<int>& target_shape,
DenseTensor* in_grad) {
auto x_dims = x.dims();

if (in_grad->dims() == out_grad.dims()) {
phi::Copy(context, out_grad, context.GetPlace(), false, in_grad);
return;
}

auto vec_in_dims = phi::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
@@ -65,64 +71,56 @@
  }

  int dims = reduce_dims_vec.size();
- bool just_copy = true;
- for (size_t i = 0; i < repeat_times.size(); i++) {
-   if (repeat_times[i] != 1) {
-     just_copy = false;
-     break;
-   }
- }
- // no need reduce, just copy
- if (just_copy) {
-   context.template Alloc<T>(in_grad);
-   phi::Copy(context, out_grad, context.GetPlace(), false, in_grad);
- } else {
-   PADDLE_ENFORCE_GE(
-       dims,
-       1,
-       errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
-                               "expand_as_v2_grad op must be greater than or "
-                               "equal to 1, but the value received is %d.",
-                               dims));
-   PADDLE_ENFORCE_LE(dims,
-                     MAX_RANK_SUPPORTED,
-                     errors::InvalidArgument(
-                         "The rank of the input 'Out@GRAD' for "
-                         "expand_as_v2_grad op must be less than or equal "
-                         "to %d, but the value received is %d.",
-                         MAX_RANK_SUPPORTED,
-                         dims));
-   switch (dims) {
-     case 1:
-       ExpandAsBackward<Context, T, 1>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     case 2:
-       ExpandAsBackward<Context, T, 2>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     case 3:
-       ExpandAsBackward<Context, T, 3>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     case 4:
-       ExpandAsBackward<Context, T, 4>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     case 5:
-       ExpandAsBackward<Context, T, 5>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     case 6:
-       ExpandAsBackward<Context, T, 6>(
-           context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
-       break;
-     default:
-       PADDLE_THROW(errors::InvalidArgument(
-           "Only support tensor with rank being between 1 and 6. But "
-           "received tensor's rank = %d.",
-           dims));
-   }
- }
+
+ PADDLE_ENFORCE_GE(
+     dims,
+     0,
+     errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
+                             "expand_as_v2_grad op must be greater than or "
+                             "equal to 0, but the value received is %d.",
+                             dims));
+ PADDLE_ENFORCE_LE(
+     dims,
+     MAX_RANK_SUPPORTED,
+     errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
+                             "expand_as_v2_grad op must be less than or equal "
+                             "to %d, but the value received is %d.",
+                             MAX_RANK_SUPPORTED,
+                             dims));
+ switch (dims) {
+   case 0:
+     ExpandAsBackward<Context, T, 0>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 1:
+     ExpandAsBackward<Context, T, 1>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 2:
+     ExpandAsBackward<Context, T, 2>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 3:
+     ExpandAsBackward<Context, T, 3>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 4:
+     ExpandAsBackward<Context, T, 4>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 5:
+     ExpandAsBackward<Context, T, 5>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   case 6:
+     ExpandAsBackward<Context, T, 6>(
+         context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
+     break;
+   default:
+     PADDLE_THROW(errors::InvalidArgument(
+         "Only support tensor with rank being between 1 and 6. But "
+         "received tensor's rank = %d.",
+         dims));
+ }
 }
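The rewritten gradient kernel drops the old `just_copy` scan in favor of an up-front `in_grad->dims() == out_grad.dims()` copy, then always dispatches on rank; the new `case 0` is what admits 0-d gradients. Because `ExpandAsBackward` takes the rank as a compile-time template parameter, a runtime rank has to be funneled through a switch into one instantiation per rank. That pattern in isolation (a sketch, not the Paddle code):

```cpp
#include <cstdio>

// Stand-in for ExpandAsBackward<Context, T, Rank>, which receives its rank
// as a compile-time template parameter.
template <int Rank>
void ReduceGradWithStaticRank() {
  std::printf("reducing with static rank %d\n", Rank);
}

void DispatchByRank(int dims) {
  switch (dims) {
    case 0: ReduceGradWithStaticRank<0>(); break;  // newly supported 0-d case
    case 1: ReduceGradWithStaticRank<1>(); break;
    case 2: ReduceGradWithStaticRank<2>(); break;
    case 3: ReduceGradWithStaticRank<3>(); break;
    case 4: ReduceGradWithStaticRank<4>(); break;
    case 5: ReduceGradWithStaticRank<5>(); break;
    case 6: ReduceGradWithStaticRank<6>(); break;  // MAX_RANK_SUPPORTED
    default: std::printf("unsupported rank %d\n", dims);
  }
}

int main() {
  DispatchByRank(0);  // would previously have tripped the rank >= 1 check
  return 0;
}
```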

9 changes: 8 additions & 1 deletion paddle/phi/kernels/impl/expand_as_kernel_impl.h
@@ -34,6 +34,10 @@ void ExpandAs(const Context& context,
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
+ if (Rank == 0) {
+   phi::Copy<Context>(context, x, context.GetPlace(), false, out);
+   return;
+ }
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(
target_shape[i],
@@ -108,7 +112,7 @@ void ExpandAsKernel(const Context& ctx,
rank));
PADDLE_ENFORCE_GE(
rank,
-     1,
+     0,
errors::InvalidArgument("The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
@@ -133,6 +137,9 @@
}

switch (target_rank) {
+   case 0:
+     ExpandAs<Context, T, 0>(ctx, x, real_target_shape, out);
+     break;
case 1:
ExpandAs<Context, T, 1>(ctx, x, real_target_shape, out);
break;
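The forward kernel gets the matching treatment: a rank-0 target is dispatched to `ExpandAs<Context, T, 0>`, which short-circuits to a plain copy before any broadcast machinery runs, and the rank floor drops from 1 to 0 to match the non-negative wording above. A compact model of the forward shape contract (an illustration under assumed broadcasting rules, not the Paddle implementation):

```cpp
#include <cassert>
#include <vector>

// Shape rule for expand_as after this change: the output takes the target's
// shape, and an empty target (a 0-d 'y') is now acceptable.
std::vector<int> ExpandAsShape(const std::vector<int>& x_dims,
                               const std::vector<int>& target) {
  assert(target.size() >= x_dims.size());  // rank may only grow
  const size_t diff = target.size() - x_dims.size();
  for (size_t i = diff; i < target.size(); ++i) {
    // each aligned input extent must be 1 or already match the target
    assert(x_dims[i - diff] == 1 || x_dims[i - diff] == target[i]);
  }
  return target;  // {} in, {} out: the Rank == 0 path is a plain copy
}

int main() {
  assert(ExpandAsShape({}, {}).empty());               // 0-d -> 0-d
  assert(ExpandAsShape({1, 4}, {3, 4}).size() == 2);   // ordinary broadcast
  return 0;
}
```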