[Zero-Dim] Support output 0D for numel/size/optimizer/create_parameter/create_global_var, test=allcase #51566

Merged 1 commit on Mar 20, 2023
5 changes: 3 additions & 2 deletions paddle/phi/core/kernel_utils.h
@@ -343,8 +343,9 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
inline bool recompute_reduce_all(const DenseTensor& x,
const IntArray& dims,
bool reduce_all = false) {
- if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
- reduce_all) {
+ if (dims.size() == 0 || x.dims().size() == 0 ||
+ static_cast<int>(dims.size()) == x.dims().size() || reduce_all) {
+ // when input 0D, it can only reduce_all
return true;
} else {
return false;
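The added `x.dims().size() == 0` check treats a rank-0 input as a full reduction, since a 0D tensor has no axes left to reduce over. A minimal sketch of the resulting user-visible behaviour, assuming a Paddle build that contains this PR (shapes in the comments are expectations, not captured output):

import paddle

x = paddle.full([], 3.0)   # 0D tensor, shape []
y = paddle.sum(x)          # no axis given -> recompute_reduce_all() returns true
print(y.shape)             # expected: []
print(float(y))            # expected: 3.0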
51 changes: 27 additions & 24 deletions paddle/phi/infermeta/unary.cc
@@ -2950,24 +2950,32 @@ DDim ReduceInferDim(const MetaTensor& x,

std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
- PADDLE_ENFORCE_LT(axis[i],
- x_rank,
- errors::InvalidArgument(
- "The reduce dim index %d should be in the "
- "range [ -dimension(X), dimension(X) ) "
- "which dimesion = %d. But received dim index = %d.",
- i,
- x_rank,
- axis[i]));
- PADDLE_ENFORCE_GE(axis[i],
- -x_rank,
- errors::InvalidArgument(
- "The reduce dim index %d should be in the "
- "range [ -dimension(X), dimension(X) ) "
- "which dimesion = %d. But received dim index = %d.",
- i,
- x_rank,
- axis[i]));
+ if (x_rank == 0) {
+ PADDLE_ENFORCE_EQ(
+ axis[i] == 0 || axis[i] == -1,
+ true,
+ phi::errors::InvalidArgument(
+ "When input 0D Tensor, the axis can only be -1, 0, None or []"));
+ } else {
+ PADDLE_ENFORCE_LT(axis[i],
+ x_rank,
+ errors::InvalidArgument(
+ "The reduce dim index %d should be in the "
+ "range [ -dimension(X), dimension(X) ) "
+ "which dimesion = %d. But received dim index = %d.",
+ i,
+ x_rank,
+ axis[i]));
+ PADDLE_ENFORCE_GE(axis[i],
+ -x_rank,
+ errors::InvalidArgument(
+ "The reduce dim index %d should be in the "
+ "range [ -dimension(X), dimension(X) ) "
+ "which dimesion = %d. But received dim index = %d.",
+ i,
+ x_rank,
+ axis[i]));
+ }

if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
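For a rank-0 input the only axis values the new branch accepts are 0 and -1 (besides passing no axis at all). A rough illustration, assuming a build with this PR; the commented-out call is expected to raise the InvalidArgument error added above:

import paddle

x = paddle.full([], 1.5)     # rank 0
paddle.sum(x)                # OK: no axis -> reduce_all
paddle.sum(x, axis=0)        # OK for 0D input
paddle.sum(x, axis=-1)       # OK for 0D input
# paddle.sum(x, axis=1)      # expected to fail: "the axis can only be -1, 0, None or []"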
@@ -3356,12 +3364,7 @@ void ShardIndexInferMeta(const MetaTensor& in,

void NumelInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
- if (input.dims().size() == 0) {
- out->set_dims(phi::make_ddim({}));
- } else {
- // TODO(zhouwei): will change shape [1] to [] to support zero-dim
- out->set_dims(phi::make_ddim({1}));
- }
+ out->set_dims(phi::make_ddim({}));
}

void SliceRawInferMeta(const MetaTensor& input,
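`NumelInferMeta` now unconditionally infers an empty (0D) shape for the output. A quick sanity check of the expected shapes, assuming a build with this change:

import paddle

x = paddle.rand([2, 3])
n = paddle.numel(x)
print(n.shape)    # expected: [] (previously [1])
print(int(n))     # 6
print(paddle.numel(paddle.full([], 1.0)).shape)   # expected: [] for a 0D input too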
65 changes: 0 additions & 65 deletions paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
@@ -22,53 +22,6 @@
#include "paddle/phi/kernels/impl/reduce_grad.h"
namespace phi {

- template <typename T, typename Context>
- void ComputeFromInput(const Context& dev_ctx,
- const DenseTensor& x,
- const DenseTensor& input2,
- const std::vector<int64_t>& dims,
- DenseTensor* x_grad) {
- auto* input0 = &x;
- auto* output = x_grad;
- dev_ctx.template Alloc<T>(output);
-
- const auto* input2_d = input2.data<T>();
- auto* output_d = output->data<T>();
-
- // handle reduce_all
- if (input2.dims().size() == 1 && input2.dims()[0] == 1) {
- for (int64_t i = 0; i < phi::product(input0->dims()); ++i) {
- output_d[i] = input2_d[0];
- }
- return;
- }
-
- // handle reduce by one dimension
- int reduce_dim_index = dims[0];
- if (reduce_dim_index < 0) {
- reduce_dim_index += input0->dims().size();
- }
-
- auto& input_dim = input0->dims();
- int64_t before_dim = 1;
- for (int i = 0; i < reduce_dim_index; ++i) {
- before_dim *= input_dim[i];
- }
- int64_t reduce_dim = input_dim[reduce_dim_index];
- int64_t after_dim = 1;
- for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
- after_dim *= input_dim[i];
- }
- for (int64_t i = 0; i < before_dim; ++i) {
- for (int64_t j = 0; j < reduce_dim; ++j) {
- for (int64_t k = 0; k < after_dim; ++k) {
- output_d[i * reduce_dim * after_dim + j * after_dim + k] =
- input2_d[i * after_dim + k];
- }
- }
- }
- }

template <typename T, typename Context>
void ReduceSumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -78,24 +31,6 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
- if (dims.size() == 1) {
- if (out_grad.dtype() != x.dtype()) {
- DenseTensorMeta x_grad_meta(
- out_grad.dtype(), x_grad->dims(), x_grad->layout());
- DenseTensor x_grad_tmp =
- phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
-
- ComputeFromInput<T, Context>(
- dev_ctx, x, out_grad, dims.GetData(), &x_grad_tmp);
-
- phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
-
- } else {
- ComputeFromInput<T, Context>(
- dev_ctx, x, out_grad, dims.GetData(), x_grad);
- }
- }

ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
x,
paddle::none,
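With the single-axis fast path removed, every case, including 0D inputs, now flows through the generic `ReduceGradKernel`. A short dygraph sketch of the gradient shape one would expect after this change (illustrative only, assuming a build with this PR):

import paddle

x = paddle.full([], 2.0)
x.stop_gradient = False
y = paddle.sum(x)          # reduce_all over a 0D tensor
y.backward()
print(x.grad.shape)        # expected: [] — the gradient keeps the 0D shape
print(float(x.grad))       # expected: 1.0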
8 changes: 4 additions & 4 deletions python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -14,6 +14,7 @@

import math
import warnings
+ import numpy as np

import paddle
from .. import unique_name
@@ -953,10 +954,9 @@ def step(self, loss):

# loss must be 1-D Tensor with shape [1]
check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
- assert len(loss.shape) == 1 and loss.shape[0] == 1, (
- "the loss.shape "
- "should be (1L,), but the current loss.shape is {}. Maybe that "
- "you should call paddle.mean to process it first.".format(
+ assert np.prod(loss.shape) == 1, (
+ "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. "
+ "Maybe that you should call paddle.mean to process it first.".format(
loss.shape
)
)
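The relaxed assertion only requires that `loss` contain a single element, so both a shape-[1] tensor and a 0D tensor pass. The check itself is plain NumPy, for example:

import numpy as np

print(np.prod((1,)))    # 1   -> a shape [1] loss passes
print(np.prod(()))      # 1.0 -> a 0D loss passes as well
print(np.prod((2, 3)))  # 6   -> a multi-element loss still trips the assert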
3 changes: 2 additions & 1 deletion python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -728,7 +728,8 @@ def block(self):
return framework.default_main_program().global_block()

def __nonzero__(self):
- numel = np.prod(self.shape)
+ # np.prod([]) -> np.float64, so use int
+ numel = int(np.prod(self.shape))
assert (
numel == 1
), "When Variable is used as the condition of if/while , Variable can only contain one element."
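The new comment is the whole rationale: `np.prod` over an empty shape returns a NumPy float64, so the value is cast to a plain int before the equality check. For example:

import numpy as np

numel = np.prod([])            # 1.0, dtype float64
print(type(numel))             # <class 'numpy.float64'>
print(int(np.prod([])) == 1)   # True -> a 0D Tensor is accepted as an if/while condition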
4 changes: 2 additions & 2 deletions python/paddle/fluid/optimizer.py
@@ -1108,8 +1108,8 @@ def backward(
else:
assert isinstance(callbacks, list)
program = loss.block.program
- assert len(loss.shape) == 1 and loss.shape[0] == 1, (
- "The loss.shape should be (1L,), but the current loss.shape is {}. "
+ assert np.prod(loss.shape) == 1, (
+ "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. "
"Maybe that you should call paddle.mean to process the current loss.".format(
loss.shape
)
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -115,7 +115,7 @@ def train_loop(main_program):
)
if avg_loss_value.dtype == numpy.uint16:
avg_loss_value = convert_uint16_to_float(avg_loss_value)
- if avg_loss_value[0] < 10.0:
+ if float(avg_loss_value) < 10.0:
if save_dirname is not None:
paddle.static.save_inference_model(
save_dirname,
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/book/test_recommender_system.py
@@ -263,7 +263,7 @@ def train_loop(main_program):
)
return

- if math.isnan(float(out[0])):
+ if math.isnan(float(out)):
sys.exit("got NaN loss, training failed.")

if is_local:
(Additional changed file; path not captured in this view.)
@@ -53,7 +53,7 @@ def __init__(self, train_id):

def forward(self, x):
is_use = (
- paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0]
+ paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).item()
and self.trainer_id == 1
)

(Additional changed file; path not captured in this view.)
@@ -105,18 +105,12 @@ def test_output(self):


class TestEqualReduceAPI(unittest.TestCase):
- def test_name(self):
- x = paddle.assign(np.array([3, 4], dtype="int32"))
- y = paddle.assign(np.array([3, 4], dtype="int32"))
- out = paddle.equal_all(x, y, name='equal_res')
- assert 'equal_res' in out.name
-
def test_dynamic_api(self):
paddle.disable_static()
x = paddle.ones(shape=[10, 10], dtype="int32")
y = paddle.ones(shape=[10, 10], dtype="int32")
out = paddle.equal_all(x, y)
- assert out.numpy()[0] is np.True_
+ assert out.item() is True
paddle.enable_static()


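Both tests now read the result of `paddle.equal_all` as a scalar: it returns a single boolean, which with zero-dim support is a 0D bool tensor, so `.item()` is the natural accessor. Roughly, assuming a build with this PR:

import paddle

paddle.disable_static()
x = paddle.ones([10, 10], dtype="int32")
y = paddle.ones([10, 10], dtype="int32")
out = paddle.equal_all(x, y)
print(out.shape)    # expected: []
print(out.item())   # True
paddle.enable_static()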
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_compiled_program.py
@@ -48,7 +48,7 @@ def setUp(self):
feed={"image": self.img, "label": self.label},
fetch_list=[loss.name],
)
- self.loss = loss_data[0]
+ self.loss = float(loss_data)

def test_compiled_program_base(self):
with new_program_scope():
@@ -70,7 +70,7 @@ def test_compiled_program_base(self):
feed={"image": self.img, "label": self.label},
fetch_list=[loss.name],
)
- np.testing.assert_array_equal(loss_data[0], self.loss)
+ np.testing.assert_array_equal(float(loss_data), self.loss)


class TestCompiledProgramError(unittest.TestCase):
(Additional changed file; path not captured in this view.)
@@ -553,7 +553,7 @@ def run(self, image_real, label_org, label_trg):

self.clear_gradients()

- return g_loss.numpy()[0], d_loss.numpy()[0]
+ return float(g_loss), float(d_loss)


class StaticGraphTrainModel:
5 changes: 2 additions & 3 deletions python/paddle/fluid/tests/unittests/test_numel_op.py
@@ -30,7 +30,6 @@ def setUp(self):
self.inputs = {
'Input': x,
}
- # TODO(zhouwei): will change shape [1] to [] to support zero-dim
self.outputs = {'Out': np.array([np.size(x)])}

def test_check_output(self):
@@ -73,10 +72,10 @@ def test_numel_static(self):
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
- res_1, np.array([np.size(input_1)]).astype("int64")
+ res_1, np.array(np.size(input_1)).astype("int64")
)
assert np.array_equal(
- res_2, np.array([np.size(input_2)]).astype("int64")
+ res_2, np.array(np.size(input_2)).astype("int64")
)

def test_numel_imperative(self):
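The static-graph expectation switches from a one-element vector to a 0D array: `np.array(np.size(x))` has shape `()`, matching the 0D tensor that the fetch now returns. For instance:

import numpy as np

x = np.random.rand(2, 3)
expected = np.array(np.size(x)).astype("int64")
print(expected.shape)   # () — zero-dim, versus (1,) for np.array([np.size(x)])
print(expected)         # 6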
6 changes: 3 additions & 3 deletions python/paddle/fluid/tests/unittests/test_size_op.py
@@ -33,7 +33,7 @@ def setUp(self):
self.config()
input = np.zeros(self.shape, dtype='bool')
self.inputs = {'Input': input}
self.outputs = {'Out': np.array([np.size(input)], dtype='int64')}
self.outputs = {'Out': np.array(np.size(input), dtype='int64')}

def config(self):
pass
@@ -85,10 +85,10 @@ def test_size_static(self):
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
- res_1, np.array([np.size(input_1)]).astype("int64")
+ res_1, np.array(np.size(input_1)).astype("int64")
)
assert np.array_equal(
- res_2, np.array([np.size(input_2)]).astype("int64")
+ res_2, np.array(np.size(input_2)).astype("int64")
)

def test_size_imperative(self):
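`test_size_static` follows the same pattern as the numel test above: the fetched result is compared against a 0D NumPy array built from `np.size`. A tiny check of that expectation:

import numpy as np

inp = np.zeros((2, 3), dtype="bool")
out = np.array(np.size(inp), dtype="int64")
print(out.shape, out)   # () 6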