From e195fedd1d047c88341d36deaea800bebf28a229 Mon Sep 17 00:00:00 2001 From: zhouwei25 Date: Mon, 13 Mar 2023 08:50:58 +0000 Subject: [PATCH] [Zero-Dim] Support 0D for numel/rank/size/optimizer/create_parameter/create_global_var, fix some usage to adapt 0D --- paddle/phi/core/kernel_utils.h | 5 +- paddle/phi/infermeta/unary.cc | 51 ++++---- .../phi/kernels/cpu/reduce_sum_grad_kernel.cc | 65 ---------- .../fluid/dygraph/learning_rate_scheduler.py | 8 +- .../fluid/dygraph/varbase_patch_methods.py | 3 +- python/paddle/fluid/optimizer.py | 4 +- .../fluid/tests/book/test_fit_a_line.py | 2 +- .../tests/book/test_recommender_system.py | 2 +- .../parallel_dygraph_gradient_check.py | 2 +- .../tests/unittests/test_compare_reduce_op.py | 8 +- .../tests/unittests/test_compiled_program.py | 4 +- ...perative_star_gan_with_gradient_penalty.py | 2 +- .../fluid/tests/unittests/test_numel_op.py | 5 +- .../fluid/tests/unittests/test_size_op.py | 6 +- .../tests/unittests/test_zero_dim_tensor.py | 114 ++++++++++++++++++ python/paddle/incubate/autograd/primrules.py | 5 +- .../incubate/operators/graph_send_recv.py | 2 +- python/paddle/optimizer/optimizer.py | 4 +- .../tests/imperative_test_utils.py | 2 +- 19 files changed, 169 insertions(+), 125 deletions(-) diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index eb18d0cb98c5b..ca43a73094dbc 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -343,8 +343,9 @@ struct KernelImpl { inline bool recompute_reduce_all(const DenseTensor& x, const IntArray& dims, bool reduce_all = false) { - if (dims.size() == 0 || static_cast(dims.size()) == x.dims().size() || - reduce_all) { + if (dims.size() == 0 || x.dims().size() == 0 || + static_cast(dims.size()) == x.dims().size() || reduce_all) { + // when input 0D, it can only reduce_all return true; } else { return false; diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 24cccaf120fde..91558238d0653 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2950,24 +2950,32 @@ DDim ReduceInferDim(const MetaTensor& x, std::vector formated_axis = axis; for (size_t i = 0; i < axis.size(); ++i) { - PADDLE_ENFORCE_LT(axis[i], - x_rank, - errors::InvalidArgument( - "The reduce dim index %d should be in the " - "range [ -dimension(X), dimension(X) ) " - "which dimesion = %d. But received dim index = %d.", - i, - x_rank, - axis[i])); - PADDLE_ENFORCE_GE(axis[i], - -x_rank, - errors::InvalidArgument( - "The reduce dim index %d should be in the " - "range [ -dimension(X), dimension(X) ) " - "which dimesion = %d. But received dim index = %d.", - i, - x_rank, - axis[i])); + if (x_rank == 0) { + PADDLE_ENFORCE_EQ( + axis[i] == 0 || axis[i] == -1, + true, + phi::errors::InvalidArgument( + "When input 0D Tensor, the axis can only be -1, 0, None or []")); + } else { + PADDLE_ENFORCE_LT(axis[i], + x_rank, + errors::InvalidArgument( + "The reduce dim index %d should be in the " + "range [ -dimension(X), dimension(X) ) " + "which dimesion = %d. But received dim index = %d.", + i, + x_rank, + axis[i])); + PADDLE_ENFORCE_GE(axis[i], + -x_rank, + errors::InvalidArgument( + "The reduce dim index %d should be in the " + "range [ -dimension(X), dimension(X) ) " + "which dimesion = %d. 
But received dim index = %d.", + i, + x_rank, + axis[i])); + } if (axis[i] < 0) { formated_axis[i] = axis[i] + x_rank; @@ -3356,12 +3364,7 @@ void ShardIndexInferMeta(const MetaTensor& in, void NumelInferMeta(const MetaTensor& input, MetaTensor* out) { out->set_dtype(DataType::INT64); - if (input.dims().size() == 0) { - out->set_dims(phi::make_ddim({})); - } else { - // TODO(zhouwei): will change shape [1] to [] to support zero-dim - out->set_dims(phi::make_ddim({1})); - } + out->set_dims(phi::make_ddim({})); } void SliceRawInferMeta(const MetaTensor& input, diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc index e7d73611cf041..b261e610d2073 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc @@ -22,53 +22,6 @@ #include "paddle/phi/kernels/impl/reduce_grad.h" namespace phi { -template -void ComputeFromInput(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& input2, - const std::vector& dims, - DenseTensor* x_grad) { - auto* input0 = &x; - auto* output = x_grad; - dev_ctx.template Alloc(output); - - const auto* input2_d = input2.data(); - auto* output_d = output->data(); - - // handle reduce_all - if (input2.dims().size() == 1 && input2.dims()[0] == 1) { - for (int64_t i = 0; i < phi::product(input0->dims()); ++i) { - output_d[i] = input2_d[0]; - } - return; - } - - // handle reduce by one dimension - int reduce_dim_index = dims[0]; - if (reduce_dim_index < 0) { - reduce_dim_index += input0->dims().size(); - } - - auto& input_dim = input0->dims(); - int64_t before_dim = 1; - for (int i = 0; i < reduce_dim_index; ++i) { - before_dim *= input_dim[i]; - } - int64_t reduce_dim = input_dim[reduce_dim_index]; - int64_t after_dim = 1; - for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) { - after_dim *= input_dim[i]; - } - for (int64_t i = 0; i < before_dim; ++i) { - for (int64_t j = 0; j < reduce_dim; ++j) { - for (int64_t k = 0; k < after_dim; ++k) { - output_d[i * reduce_dim * after_dim + j * after_dim + k] = - input2_d[i * after_dim + k]; - } - } - } -} - template void ReduceSumGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -78,24 +31,6 @@ void ReduceSumGradKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* x_grad) { reduce_all = recompute_reduce_all(x, dims, reduce_all); - if (dims.size() == 1) { - if (out_grad.dtype() != x.dtype()) { - DenseTensorMeta x_grad_meta( - out_grad.dtype(), x_grad->dims(), x_grad->layout()); - DenseTensor x_grad_tmp = - phi::Empty(dev_ctx, std::move(x_grad_meta)); - - ComputeFromInput( - dev_ctx, x, out_grad, dims.GetData(), &x_grad_tmp); - - phi::CastKernel(dev_ctx, x_grad_tmp, x.dtype(), x_grad); - - } else { - ComputeFromInput( - dev_ctx, x, out_grad, dims.GetData(), x_grad); - } - } - ReduceGradKernel(dev_ctx, x, paddle::none, diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py index cf794ad4cef89..2b368190d57c0 100644 --- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py +++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py @@ -14,6 +14,7 @@ import math import warnings +import numpy as np import paddle from .. import unique_name @@ -953,10 +954,9 @@ def step(self, loss): # loss must be 1-D Tensor with shape [1] check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step') - assert len(loss.shape) == 1 and loss.shape[0] == 1, ( - "the loss.shape " - "should be (1L,), but the current loss.shape is {}. 
Maybe that " - "you should call paddle.mean to process it first.".format( + assert np.prod(loss.shape) == 1, ( + "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. " + "Maybe that you should call paddle.mean to process it first.".format( loss.shape ) ) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 70301d3650150..25e4531eb4efb 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -728,7 +728,8 @@ def block(self): return framework.default_main_program().global_block() def __nonzero__(self): - numel = np.prod(self.shape) + # np.prod([]) -> np.float64, so use int + numel = int(np.prod(self.shape)) assert ( numel == 1 ), "When Variable is used as the condition of if/while , Variable can only contain one element." diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 08717825f7935..5c1bd44d10ac3 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1108,8 +1108,8 @@ def backward( else: assert isinstance(callbacks, list) program = loss.block.program - assert len(loss.shape) == 1 and loss.shape[0] == 1, ( - "The loss.shape should be (1L,), but the current loss.shape is {}. " + assert np.prod(loss.shape) == 1, ( + "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. " "Maybe that you should call paddle.mean to process the current loss.".format( loss.shape ) diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 65ec225b05ee1..36e34e85462cd 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -115,7 +115,7 @@ def train_loop(main_program): ) if avg_loss_value.dtype == numpy.uint16: avg_loss_value = convert_uint16_to_float(avg_loss_value) - if avg_loss_value[0] < 10.0: + if float(avg_loss_value) < 10.0: if save_dirname is not None: paddle.static.save_inference_model( save_dirname, diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py index c60e8d4fba3cc..0be40deb6b553 100644 --- a/python/paddle/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/fluid/tests/book/test_recommender_system.py @@ -263,7 +263,7 @@ def train_loop(main_program): ) return - if math.isnan(float(out[0])): + if math.isnan(float(out)): sys.exit("got NaN loss, training failed.") if is_local: diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py index 036b9d967e861..e9301acec75b0 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py @@ -53,7 +53,7 @@ def __init__(self, train_id): def forward(self, x): is_use = ( - paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0] + paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).item() and self.trainer_id == 1 ) diff --git a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py index 27c37161a78dd..121bb98f6c291 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py @@ -105,18 +105,12 @@ def test_output(self): class TestEqualReduceAPI(unittest.TestCase): - def test_name(self): - x = paddle.assign(np.array([3, 4], dtype="int32")) - y = paddle.assign(np.array([3, 4], dtype="int32")) - out = paddle.equal_all(x, y, name='equal_res') - assert 'equal_res' in out.name - def test_dynamic_api(self): paddle.disable_static() x = paddle.ones(shape=[10, 10], dtype="int32") y = paddle.ones(shape=[10, 10], dtype="int32") out = paddle.equal_all(x, y) - assert out.numpy()[0] is np.True_ + assert out.item() is True paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_compiled_program.py b/python/paddle/fluid/tests/unittests/test_compiled_program.py index f698980b163f8..e3ebfe64d0872 100644 --- a/python/paddle/fluid/tests/unittests/test_compiled_program.py +++ b/python/paddle/fluid/tests/unittests/test_compiled_program.py @@ -48,7 +48,7 @@ def setUp(self): feed={"image": self.img, "label": self.label}, fetch_list=[loss.name], ) - self.loss = loss_data[0] + self.loss = float(loss_data) def test_compiled_program_base(self): with new_program_scope(): @@ -70,7 +70,7 @@ def test_compiled_program_base(self): feed={"image": self.img, "label": self.label}, fetch_list=[loss.name], ) - np.testing.assert_array_equal(loss_data[0], self.loss) + np.testing.assert_array_equal(float(loss_data), self.loss) class TestCompiledProgramError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 4fb2eaa1cee4a..3f1257cbd8e80 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -553,7 +553,7 @@ def run(self, image_real, label_org, label_trg): self.clear_gradients() - return g_loss.numpy()[0], d_loss.numpy()[0] + return float(g_loss), float(d_loss) class StaticGraphTrainModel: diff --git a/python/paddle/fluid/tests/unittests/test_numel_op.py b/python/paddle/fluid/tests/unittests/test_numel_op.py index a2414ed369b9b..8a90883138415 100644 --- a/python/paddle/fluid/tests/unittests/test_numel_op.py +++ b/python/paddle/fluid/tests/unittests/test_numel_op.py @@ -30,7 +30,6 @@ def setUp(self): self.inputs = { 'Input': x, } - # TODO(zhouwei): will change shape [1] to [] to support zero-dim self.outputs = {'Out': np.array([np.size(x)])} def test_check_output(self): @@ -73,10 +72,10 @@ def test_numel_static(self): ) # TODO(zhouwei): will change shape [1] to [] to support zero-dim assert np.array_equal( - res_1, np.array([np.size(input_1)]).astype("int64") + res_1, np.array(np.size(input_1)).astype("int64") ) assert np.array_equal( - res_2, np.array([np.size(input_2)]).astype("int64") + res_2, np.array(np.size(input_2)).astype("int64") ) def test_numel_imperative(self): diff --git a/python/paddle/fluid/tests/unittests/test_size_op.py b/python/paddle/fluid/tests/unittests/test_size_op.py index edea44abf0890..edef25ed7a783 100644 --- a/python/paddle/fluid/tests/unittests/test_size_op.py +++ b/python/paddle/fluid/tests/unittests/test_size_op.py @@ -33,7 +33,7 @@ def setUp(self): self.config() input = np.zeros(self.shape, dtype='bool') self.inputs = {'Input': input} - self.outputs = {'Out': np.array([np.size(input)], dtype='int64')} + self.outputs = {'Out': np.array(np.size(input), dtype='int64')} def config(self): pass @@ -85,10 
+85,10 @@ def test_size_static(self): ) # TODO(zhouwei): will change shape [1] to [] to support zero-dim assert np.array_equal( - res_1, np.array([np.size(input_1)]).astype("int64") + res_1, np.array(np.size(input_1)).astype("int64") ) assert np.array_equal( - res_2, np.array([np.size(input_2)]).astype("int64") + res_2, np.array(np.size(input_2)).astype("int64") ) def test_size_imperative(self): diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 8d0b38a45ed91..a49fb6ae9b834 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -204,6 +204,20 @@ def test_dygraph_reduce(self): np.testing.assert_allclose(x.grad.numpy(), np.array(1.0)) np.testing.assert_allclose(out.grad.numpy(), np.array(1.0)) + out1 = api(x, 0) + self.assertEqual(out1.shape, []) + self.assertEqual(out1, out) + out1.backward() + + out2 = api(x, -1) + self.assertEqual(out2.shape, []) + self.assertEqual(out2, out) + out2.backward() + + if x.grad is not None: + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad.numpy(), np.array(3.0)) + paddle.enable_static() def test_static_reduce(self): @@ -227,6 +241,12 @@ def test_static_reduce(self): out_empty_list = api(x, None) self.assertEqual(out_empty_list.shape, ()) + out1 = api(x, 0) + self.assertEqual(out1.shape, ()) + + out2 = api(x, -1) + self.assertEqual(out2.shape, ()) + fetch_list = [x, out] if block.has_var(x.grad_name): fetch_list.extend([x.grad_name, out.grad_name]) @@ -550,6 +570,16 @@ def setUp(self): paddle.disable_static() self.x = paddle.rand([]) + def test_create_parameter_var(self): + zero_dim_param = paddle.create_parameter(shape=[], dtype='float32') + self.assertEqual(zero_dim_param.shape, []) + + zero_dim_var = paddle.tensor.creation.create_global_var( + shape=[], value=0.5, dtype='float32' + ) + self.assertEqual(zero_dim_var.shape, []) + self.assertEqual(zero_dim_var.item(), 0.5) + def test_expand(self): # case1 x = paddle.full([], 1, 'float32') @@ -955,15 +985,29 @@ def test_numpy(self): np.testing.assert_array_equal(x.numpy(), np.array(0.5)) def test_numel(self): + # 1) x is 0D out = paddle.numel(self.x) self.assertEqual(out.shape, []) np.testing.assert_array_equal(out.numpy(), np.array(1)) + # 2) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.numel(x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(15)) + def test_rank(self): + # 1) x is 0D out = paddle.rank(self.x) self.assertEqual(out.shape, []) np.testing.assert_array_equal(out.numpy(), np.array(0)) + # 1) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.rank(x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(2)) + def test_shape(self): out = paddle.shape(self.x) self.assertEqual(out.shape, [0]) @@ -1878,6 +1922,23 @@ def setUp(self): paddle.enable_static() self.exe = paddle.static.Executor() + @prog_scope() + def test_create_parameter_var(self): + zero_dim_param = paddle.create_parameter(shape=[], dtype='float32') + self.assertEqual(zero_dim_param.shape, ()) + prog = paddle.static.default_startup_program() + res = self.exe.run(prog, fetch_list=[zero_dim_param]) + self.assertEqual(res[0].shape, ()) + + zero_dim_var = paddle.static.create_global_var( + shape=[], value=0.5, dtype='float32' + ) + self.assertEqual(zero_dim_var.shape, ()) + prog = paddle.static.default_startup_program() + res = self.exe.run(prog, 
fetch_list=[zero_dim_var]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 0.5) + @prog_scope() def test_expand(self): x = paddle.full([], 1, 'float32') @@ -3118,6 +3179,7 @@ def test_sequence_pad(self): res = self.exe.run(prog, feed={"x": x_tensor}, fetch_list=[out]) self.assertEqual(res[0].shape, (3, 4, 2)) + @prog_scope() def test_prelu(self): x1 = paddle.full([], 1.0, 'float32') x1.stop_gradient = False @@ -3150,6 +3212,7 @@ def test_prelu(self): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) + @prog_scope() def test_static_nn_prelu(self): x1 = paddle.full([], 1.0, 'float32') x1.stop_gradient = False @@ -3210,6 +3273,57 @@ def body(i, x): self.assertEqual(res[3].shape, ()) np.testing.assert_allclose(res[3], np.array(1.0)) + @prog_scope() + def test_numel(self): + # 1) x is 0D + x = paddle.full([], 0.5) + out = paddle.numel(x) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + np.testing.assert_array_equal(res[0], np.array(1)) + + # 2) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.numel(x) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + np.testing.assert_array_equal(res[0], np.array(15)) + + @prog_scope() + def test_rank(self): + # 1) x is 0D + x = paddle.full([], 0.5) + out = paddle.rank(x) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + np.testing.assert_array_equal(res[0], np.array(0)) + + # 1) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.rank(x) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + np.testing.assert_array_equal(res[0], np.array(2)) + + @prog_scope() + def _test_shape(self): + x = paddle.full([], 0.5) + out = paddle.shape(x) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + # 0-Size should be [ np.array([]) ], its [None] now + self.assertEqual(res[0].shape, (0)) + np.testing.assert_array_equal(res[0], np.array([])) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. 
class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 137747e75da15..0717b6fef92e2 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -605,10 +605,7 @@ def batch_norm_orig2prim( @REGISTER_ORIG2PRIM('size') def size_orig2prim(op, x): - # TODO(zhouwei): will change shape [1] to [] to support zero-dim - return fill_const( - functools.reduce(operator.mul, x.shape), (1,), paddle.int64 - ) + return fill_const(functools.reduce(operator.mul, x.shape), (), paddle.int64) # Register prim2orig lower rules diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 3df4ebbbe6caf..5d92a7f6d065c 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -182,7 +182,7 @@ def convert_out_size_to_list(out_size): elif isinstance(out_size, (int, np.int32, np.int64)): out_size = [out_size] else: - out_size = [out_size.numpy().astype(int)[0]] + out_size = [int(out_size)] return out_size diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 43f15b55e3a2b..1a8c51d9d8faf 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -1100,8 +1100,8 @@ def backward( else: assert isinstance(callbacks, list) program = loss.block.program - assert len(loss.shape) == 1 and loss.shape[0] == 1, ( - "The loss.shape should be (1L,), but the current loss.shape is {}. " + assert np.prod(loss.shape) == 1, ( + "The number of elements of loss should be 1, but the current loss.shape is {}, whose number of elements is not 1. " "Maybe that you should call paddle.mean to process the current loss.".format( loss.shape ) diff --git a/python/paddle/static/quantization/tests/imperative_test_utils.py b/python/paddle/static/quantization/tests/imperative_test_utils.py index 3ba7b9ffef676..a583432e39cfe 100644 --- a/python/paddle/static/quantization/tests/imperative_test_utils.py +++ b/python/paddle/static/quantization/tests/imperative_test_utils.py @@ -89,7 +89,7 @@ def train_lenet(lenet, reader, optimizer): lenet.clear_gradients() if batch_id % 100 == 0: - loss_list.append(avg_loss.numpy()[0]) + loss_list.append(float(avg_loss)) _logger.info('{}: {}'.format('loss', avg_loss.numpy())) return loss_list
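
Illustrative sketch (not part of the patch itself): the snippet below shows the 0D-tensor semantics this change targets for numel/rank and for a 0D loss passed to an optimizer, assuming a Paddle build that includes this patch. The shapes and values in the comments are the expected results, not verified output.

import paddle

# A 0D (zero-dimensional) tensor has shape [] and exactly one element.
x = paddle.full([], 0.5)

# With this change, numel/rank return 0D int64 tensors instead of shape [1].
print(paddle.numel(x).shape)   # expected: []
print(paddle.rank(x).shape)    # expected: []

n = paddle.full([3, 5], 0.5)
print(paddle.numel(n).item())  # expected: 15
print(paddle.rank(n).item())   # expected: 2

# A 0D loss is now accepted by the optimizer's backward path
# (the old assertion required loss.shape == [1]).
linear = paddle.nn.Linear(4, 1)
loss = paddle.mean(linear(paddle.rand([8, 4])))  # mean of a tensor reduces to 0D
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
loss.backward()
opt.step()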