Commit
[Zero-Dim] support input 0D Tensor for softmax/log_softmax/gumbel_softmax
zhwesky2010 committed Oct 27, 2022
1 parent 4717329 commit e442f24
Showing 21 changed files with 290 additions and 48 deletions.
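
For a 0D (rank-0) tensor there is no axis to normalize over: softmax degenerates to exp(x)/exp(x) = 1.0, log_softmax to log(1) = 0.0, and since both outputs are constant in x, the corresponding input gradients are identically 0. A minimal plain-C++ sketch of this degenerate arithmetic (illustrative only, not Paddle code; the kernels below hard-code the same results via phi::funcs::set_constant):

#include <cmath>
#include <cstdio>

int main() {
  double x = 3.7;                              // the single value held by a 0D tensor
  double softmax = std::exp(x) / std::exp(x);  // a value normalized against itself: always 1.0
  double log_softmax = std::log(softmax);      // log(1.0): always 0.0
  std::printf("softmax=%.1f log_softmax=%.1f\n", softmax, log_softmax);
  return 0;
}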
8 changes: 0 additions & 8 deletions paddle/fluid/operators/randint_op.cc
@@ -77,14 +77,6 @@ class RandintOp : public framework::OperatorWithKernel {
       return;
     }
 
-    PADDLE_ENFORCE_EQ(shape.empty(),
-                      false,
-                      platform::errors::InvalidArgument(
-                          "if there is no Input(ShapeTensorList) and no "
-                          "Input(ShapeTensor),the "
-                          "attr(shape) information must "
-                          "be set by Attr(shape)."));
-
     std::vector<int64_t> tensor_shape;
     tensor_shape.reserve(shape.size());
     for (auto dim : shape) {
5 changes: 0 additions & 5 deletions paddle/phi/infermeta/nullary.cc
@@ -112,11 +112,6 @@ void RandintInferMeta(
           high));
 
   auto& shape_vector = shape.GetData();
-  PADDLE_ENFORCE_EQ(
-      shape_vector.empty(),
-      false,
-      errors::InvalidArgument("The shape information should not be empty, it "
-                              "must be set by Attr(shape)."));
 
   std::vector<int64_t> tensor_shape;
   tensor_shape.reserve(shape_vector.size());
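
Both deletions above relax the same constraint: an empty Attr(shape) used to be rejected as InvalidArgument, but with 0D support an empty shape vector legitimately describes a rank-0 output, so the guards are dropped and tensor_shape is simply left empty.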
72 changes: 46 additions & 26 deletions paddle/phi/infermeta/unary.cc
@@ -3108,16 +3108,29 @@ void SliceRawInferMeta(const MetaTensor& input,
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
   auto dim_x = x.dims();
   auto rank_x = dim_x.size();
-  PADDLE_ENFORCE_GE(axis,
-                    -rank_x,
-                    phi::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], "
-                        "R is the rank of Input(X)."));
-  PADDLE_ENFORCE_LT(axis,
-                    rank_x,
-                    phi::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], "
-                        "R is the rank of Input(X)."));
+  if (rank_x > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -rank_x,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+    PADDLE_ENFORCE_LT(axis,
+                      rank_x,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+  } else {
+    PADDLE_ENFORCE_GE(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+    PADDLE_ENFORCE_LE(
+        axis,
+        0,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+  }
 
   out->set_dims(x.dims());
   out->set_dtype(x.dtype());
@@ -3963,22 +3976,29 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
                                  int axis,
                                  MetaTensor* out) {
   auto rank = x.dims().size();
-  PADDLE_ENFORCE_GE(
-      axis,
-      -rank,
-      phi::errors::InvalidArgument(
-          "Attr(axis) value should be in range [-R, R-1], "
-          "R is the rank of Input(X). But received axis: %d, R: %d.",
-          axis,
-          rank));
-  PADDLE_ENFORCE_LT(
-      axis,
-      rank,
-      phi::errors::InvalidArgument(
-          "Attr(axis) value should be in range [-R, R-1], "
-          "R is the rank of Input(X). But received axis: %d, R: %d.",
-          axis,
-          rank));
+  if (rank > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -rank,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+    PADDLE_ENFORCE_LT(axis,
+                      rank,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+  } else if (rank == 0) {
+    PADDLE_ENFORCE_GE(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+    PADDLE_ENFORCE_LE(
+        axis,
+        0,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+  }
   out->share_meta(x);
 }
 
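
The two infermeta hunks above enforce one rule: for rank R > 0 the axis must lie in [-R, R-1], while a 0D tensor accepts only -1 or 0 (the degenerate form of the same range). A hedged standalone sketch of that check (axis_is_valid is an illustrative name, not a Paddle API):

#include <cassert>

// Mirrors the rank-aware range checks added above.
bool axis_is_valid(int axis, int rank) {
  if (rank > 0) {
    return axis >= -rank && axis < rank;  // [-R, R-1]
  }
  return axis >= -1 && axis <= 0;  // 0D tensors: only -1 or 0
}

int main() {
  assert(axis_is_valid(-1, 0) && axis_is_valid(0, 0));
  assert(!axis_is_valid(1, 0));  // out of range for a 0D tensor
  assert(axis_is_valid(-3, 3) && !axis_is_valid(3, 3));
  return 0;
}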
6 changes: 6 additions & 0 deletions paddle/phi/kernels/cpu/log_softmax_grad_kernel.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -71,6 +72,11 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
   const int canonical_axis = funcs::CanonicalAxis(axis, rank);
 
   dev_ctx.template Alloc<T>(x_grad);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
   if (out.numel() != 0) {
     LogSoftmaxGradFunctor<Context, T>()(
         dev_ctx, &out, &out_grad, x_grad, canonical_axis);
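
The zero fill follows from the forward result: for a 0D input, log_softmax(x) = x - log(exp(x)) = 0 for every x, so the derivative is identically 0 and the incoming out_grad is multiplied away. A quick numeric check in plain C++ (illustrative only):

#include <cmath>
#include <cstdio>

int main() {
  for (double x : {-2.0, 0.0, 5.5}) {
    double y = x - std::log(std::exp(x));  // 0D log_softmax: 0 for any x
    std::printf("x=%g  log_softmax=%g\n", x, y);
  }
  return 0;
}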
6 changes: 6 additions & 0 deletions paddle/phi/kernels/cpu/log_softmax_kernel.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -109,6 +110,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
   const int canonical_axis = funcs::CanonicalAxis(axis, rank);
 
   dev_ctx.template Alloc<T>(out);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   if (x.numel() != 0) {
     LogSoftmaxFunctor<Context, T>()(dev_ctx, &x, out, canonical_axis);
   }
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -27,6 +28,12 @@ void LogSoftmaxGradKernel(const Context &dev_ctx,
                           int axis,
                           DenseTensor *x_grad) {
   dev_ctx.template Alloc<T>(x_grad);
+  const int rank = out.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
   phi::SoftmaxBackwardCUDAKernelDriver<T, true>(
       dev_ctx, out, out_grad, axis, x_grad);
 }
8 changes: 8 additions & 0 deletions paddle/phi/kernels/gpu/log_softmax_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -25,7 +26,14 @@ void LogSoftmaxKernel(const Context &dev_ctx,
                       const DenseTensor &x,
                       int axis,
                       DenseTensor *out) {
+  const int rank = x.dims().size();
+
   dev_ctx.template Alloc<T>(out);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, x, axis, out);
 }
 
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -27,6 +28,14 @@ void SoftmaxGradGPUDNNKernel(const Context& dev_ctx,
                              int axis,
                              DenseTensor* x_grad) {
   dev_ctx.template Alloc<T>(x_grad);
+
+  const int rank = out.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+
   SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, out, out_grad, axis, x_grad);
 }
 
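
Why the backward kernels fill 0.0 while the forward softmax kernel below fills 1.0: the softmax Jacobian is dy_i/dx_j = y_i * (delta_ij - y_j), and for a single-element input y = 1, so every entry collapses to 1 * (1 - 1) = 0; x_grad is therefore 0 regardless of out_grad.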
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpudnn/softmax_kernel.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -26,6 +27,14 @@ void SoftmaxGPUDNNKernel(const Context& dev_ctx,
                         int axis,
                         DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+
+  const int rank = x.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 1.0);
+    return;
+  }
+
   SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
 }
 
7 changes: 7 additions & 0 deletions paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/operators/math/softmax_impl.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -37,6 +38,12 @@ void GumbelSoftmaxGradKernel(const Context& ctx,
     return;
   }
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(ctx, dx, 0.0);
+    return;
+  }
+
   const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
   DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);
7 changes: 7 additions & 0 deletions paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -21,6 +21,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -67,6 +68,12 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
     return;
   }
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(ctx, out, 1.0);
+    return;
+  }
+
   const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
   DenseTensor x_noise_2d, out_2d(*out);
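
Gumbel-softmax perturbs the logits with Gumbel(0, 1) noise and divides by a temperature before the softmax, but with a single element the normalization still cancels everything, which is why the 0D path can fill 1.0 without sampling at all. A hedged sketch of that degenerate path (plain C++, illustrative variable names):

#include <cmath>
#include <cstdio>
#include <random>

int main() {
  std::mt19937 rng(42);
  std::uniform_real_distribution<double> uniform(1e-12, 1.0);
  double x = 0.3, tau = 0.5;                      // logit and temperature
  double g = -std::log(-std::log(uniform(rng)));  // Gumbel(0, 1) sample
  double z = (x + g) / tau;
  double y = std::exp(z) / std::exp(z);           // one-element softmax
  std::printf("y=%f\n", y);                       // 1.0 for any noise or temperature
  return 0;
}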
17 changes: 12 additions & 5 deletions paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/softmax_grad_kernel.h"
 
 namespace phi {
 
@@ -26,16 +27,22 @@ void SoftmaxGradKernel(const Context& dev_ctx,
                        const DenseTensor& out_grad,
                        int axis,
                        DenseTensor* x_grad) {
-  const int rank = x_grad->dims().size();
-  const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
-  int axis_dim = x_grad->dims()[calc_axis];
-
   // allocate memory on device.
   dev_ctx.template Alloc<T>(x_grad);
+
+  const int rank = x_grad->dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+  // For zero-sized Tensor
   if (x_grad->numel() == 0) {
     return;
   }
+
+  const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
+  int axis_dim = x_grad->dims()[calc_axis];
 
   const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
   const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
   DenseTensor dX_2d, Out_2d, dOut_2d;
7 changes: 7 additions & 0 deletions paddle/phi/kernels/impl/softmax_kernel_impl.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/softmax_kernel.h"
 
 namespace phi {
 
@@ -31,9 +32,15 @@ void SoftmaxKernel(const Context& dev_ctx,
 
   // allocate memory on device.
   dev_ctx.template Alloc<T>(out);
+  // For 0-Sized Tensor
   if (out->numel() == 0) {
     return;
   }
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 1.0);
+    return;
+  }
 
   const int n = phi::funcs::SizeToAxis(calc_axis, x.dims());
   const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
7 changes: 7 additions & 0 deletions paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -30,6 +31,12 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
   const int rank = out.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+
   if (out.numel() != 0) {
     auto out_shape = phi::vectorize<int>(out.dims());
     dev_ctx.template Alloc<T>(x_grad);
6 changes: 6 additions & 0 deletions paddle/phi/kernels/xpu/log_softmax_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -29,6 +30,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   if (x.numel() != 0) {
     auto x_shape = phi::vectorize<int>(x.dims());
     dev_ctx.template Alloc<T>(out);
(6 more changed files not shown.)
