diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 30a8d77ca1d3f..c4ffc3e7de204 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -123,15 +123,29 @@ static PyObject* tensor_method_numpy(TensorObject* self,
   size_t py_rank = tensor_dims.size();
   size_t numel = 1;
   if (py_rank == 0) {
-    // 0D Tensor hack process to 1D numpy, will remove in future
-    VLOG(0) << "Warning:: 0D Tensor cannot be used as Tensor.numpy()[0], Now "
-               "0D will be changed to 1D numpy to avoid this problem, but it's "
-               "not correct and will be removed in future. Please change "
-               "'Tensor.numpy()[0]' to 'float(Tensor)' or "
-               "'Tensor.numpy().item()' as soon as possible.";
-    py_rank = 1;
-    py_dims[0] = 1;
-    py_strides[0] = sizeof_dtype * numel;
+    Py_ssize_t args_num = PyTuple_Size(args);
+    bool set_to_1d = true;
+    if (args_num == (Py_ssize_t)1) {
+      PyObject* obj = PyTuple_GET_ITEM(args, 0);
+      if (obj == Py_False) {
+        set_to_1d = false;
+      }
+    }
+    if (set_to_1d) {
+      // 0D Tensor hack process to 1D numpy, will remove in future
+      VLOG(0)
+          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]'. In "
+             "order to avoid this problem, "
+             "0D Tensor will be changed to 1D numpy currently, but it's not "
+             "correct and will be "
+             "removed in future. Please modify "
+             "'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
+             "possible, "
+             "otherwise 'Tensor.numpy()[0]' will raise an error";
+      py_rank = 1;
+      py_dims[0] = 1;
+      py_strides[0] = sizeof_dtype * numel;
+    }
   } else {
     for (int i = tensor_dims.size() - 1; i >= 0; --i) {
       py_dims[i] = static_cast<size_t>(tensor_dims[i]);
@@ -143,7 +157,7 @@ static PyObject* tensor_method_numpy(TensorObject* self,
   PyObject* array = api.PyArray_NewFromDescr_(
       api.PyArray_Type_,
       api.PyArray_DescrFromType_(numpy_dtype),
-      tensor_dims.size(),
+      py_rank,
       py_dims,
       py_strides,
       nullptr,
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 4e95e5a8ea89f..0ee0c4f759d4f 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -415,9 +415,9 @@ def _broadcast_final_loss(self):
             ), "train_batch() in last stage should obtain vaild loss"
             loss = self.total_loss.detach()
             is_fp32 = (
-                paddle.to_tensor(1)
+                paddle.full([], 1, 'int64')
                 if loss.dtype == paddle.float32
-                else paddle.to_tensor(0)
+                else paddle.full([], 0, 'int64')
             )
             paddle.distributed.broadcast(
                 is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
@@ -426,7 +426,7 @@ def _broadcast_final_loss(self):
                 loss, src=self.global_rank, sync_op=True, group=self.pp_group
             )
         else:
-            is_fp32 = paddle.to_tensor(1)
+            is_fp32 = paddle.full([], 1, 'int64')
             paddle.distributed.broadcast(
                 is_fp32,
                 src=self._hcg.get_rank_from_stage(self.num_stages - 1),
@@ -435,7 +435,7 @@ def _broadcast_final_loss(self):
             )
             loss = (
                 paddle.zeros(shape=[1], dtype="float32")
-                if is_fp32.numpy()[0]
+                if is_fp32.item()
                 else paddle.zeros(shape=[1], dtype="float16")
             )
             paddle.distributed.broadcast(
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 70301d3650150..2f00f1751f88e 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -764,7 +764,7 @@ def __array__(self, dtype=None):
                 print(type(x_array)) #<class 'numpy.ndarray'>
                 print(x_array.shape) #(2, 2)
         """
-        array = self.numpy()
+        array = self.numpy(False)
         if dtype:
             array = array.astype(dtype)
         return array
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index ac2dde6ba6998..628f9ba68a727 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -773,7 +773,7 @@ def _var_base_to_np(var_base):
         "paddle.fluid.framework._var_base_to_np is deprecated, please use var_base.numpy() instead of _var_base_to_np(var_base)."
     )

-    return var_base.numpy()
+    return var_base.numpy(False)


 def _cpu_num():
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index f2f763e3b60f0..f01dd7e60ca07 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -698,10 +698,10 @@ def unsqueeze(input, axes, name=None):
         if isinstance(axes, int):
             axes = [axes]
         elif isinstance(axes, Variable):
-            axes = axes.numpy().tolist()
+            axes = axes.tolist()
         elif isinstance(axes, (list, tuple)):
             axes = [
-                item.numpy().item(0) if isinstance(item, Variable) else item
+                item.item(0) if isinstance(item, Variable) else item
                 for item in axes
             ]
         return _C_ops.unsqueeze(input, axes)
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index e1051db52b41a..724d636c37306 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2432,12 +2432,12 @@ def _append_optimize_op(self, block, param_and_grad):
             _beta1 = (
                 self._beta1
                 if not isinstance(self._beta1, Variable)
-                else self._beta1.numpy().item(0)
+                else self._beta1.item(0)
             )
             _beta2 = (
                 self._beta2
                 if not isinstance(self._beta2, Variable)
-                else self._beta2.numpy().item(0)
+                else self._beta2.item(0)
             )
             master_weight = None
             _, _, _, _, _, _ = _legacy_C_ops.adam(
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 73dffe4f993fd..cc792d92aae56 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -952,7 +952,15 @@ def test_tolist(self):

     def test_numpy(self):
         x = paddle.full([], 0.5)
-        np.testing.assert_array_equal(x.numpy(), np.array(0.5))
+        # 0D Tensor is hacked to 1D numpy by default, will remove in future
+        x_np = x.numpy()
+        np.testing.assert_array_equal(x_np.shape, (1,))
+        np.testing.assert_array_equal(x_np, np.array([0.5]))
+
+        # numpy(False) returns the original, correct 0D numpy array
+        x_np = x.numpy(False)
+        np.testing.assert_array_equal(x_np.shape, ())
+        np.testing.assert_array_equal(x_np, np.array(0.5))

     def test_numel(self):
         out = paddle.numel(self.x)
diff --git a/python/paddle/geometric/message_passing/utils.py b/python/paddle/geometric/message_passing/utils.py
index 8172add83f4f8..5122a4f110ffe 100644
--- a/python/paddle/geometric/message_passing/utils.py
+++ b/python/paddle/geometric/message_passing/utils.py
@@ -29,7 +29,7 @@ def convert_out_size_to_list(out_size):
     elif isinstance(out_size, (int, np.int32, np.int64)):
         out_size = [out_size]
     else:
-        out_size = [out_size.numpy().astype(int)[0]]
+        out_size = [int(out_size)]
     return out_size


diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py
index 3df4ebbbe6caf..5d92a7f6d065c 100644
--- a/python/paddle/incubate/operators/graph_send_recv.py
+++ b/python/paddle/incubate/operators/graph_send_recv.py
@@ -182,7 +182,7 @@ def convert_out_size_to_list(out_size):
     elif isinstance(out_size, (int, np.int32, np.int64)):
         out_size = [out_size]
     else:
-        out_size = [out_size.numpy().astype(int)[0]]
+        out_size = [int(out_size)]
     return out_size


diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 6204a4fdbb82b..dc9250d71005f 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -503,7 +503,7 @@ def _is_list_or_turple_(data):

         for i, dim in enumerate(out_shape):
             if isinstance(dim, Variable):
-                out_shape[i] = dim.numpy().item()
+                out_shape[i] = dim.item()
         if not (_is_list_or_turple_(out_shape)):
             raise TypeError("size should be a list or tuple or Variable.")
         # Validate the shape
@@ -1693,7 +1693,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):

     if in_dygraph_mode():
         if isinstance(pad, Variable):
-            pad = pad.numpy().tolist()
+            pad = pad.tolist()
         out = _C_ops.pad3d(x, pad, mode, value, data_format)
     else:
         attrs = {'mode': mode, 'value': value, 'data_format': data_format}
diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py
index 1acf1bcbc5f74..10b0e46efe08a 100644
--- a/python/paddle/nn/functional/vision.py
+++ b/python/paddle/nn/functional/vision.py
@@ -86,17 +86,13 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):

     if in_dygraph_mode():
         _out_shape = (
-            out_shape.numpy().tolist()
-            if isinstance(out_shape, Variable)
-            else out_shape
+            out_shape.tolist() if isinstance(out_shape, Variable) else out_shape
         )
         theta = theta._use_gpudnn(use_cudnn)
         return _C_ops.affine_grid(theta, _out_shape, align_corners)
     elif in_dynamic_mode():
         _out_shape = (
-            out_shape.numpy().tolist()
-            if isinstance(out_shape, Variable)
-            else out_shape
+            out_shape.tolist() if isinstance(out_shape, Variable) else out_shape
         )
         return _legacy_C_ops.affine_grid(
             theta,
diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py
index 3ab5a896e463a..3e50fe9a6f6c2 100644
--- a/python/paddle/nn/initializer/assign.py
+++ b/python/paddle/nn/initializer/assign.py
@@ -211,6 +211,6 @@ def __init__(self, value, name=None):
         # TODO: value is already is a tensor, accounting efficiency maybe it does not need to convert tensor to numpy data and then initialized.
         if isinstance(value, paddle.static.Variable):
-            value = value.numpy()
+            value = value.numpy(False)

         super().__init__(value)
diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py
index 3fd3677c31754..59f1604f5c47f 100644
--- a/python/paddle/optimizer/adam.py
+++ b/python/paddle/optimizer/adam.py
@@ -310,12 +310,12 @@ def _append_optimize_op(self, block, param_and_grad):
             _beta1 = (
                 self._beta1
                 if not isinstance(self._beta1, Variable)
-                else self._beta1.numpy().item(0)
+                else self._beta1.item(0)
             )
             _beta2 = (
                 self._beta2
                 if not isinstance(self._beta2, Variable)
-                else self._beta2.numpy().item(0)
+                else self._beta2.item(0)
             )

             _, _, _, _, _, _ = _C_ops.adam_(
@@ -623,12 +623,12 @@ def _append_optimize_multi_tensor_op(
                 _beta1 = (
                     self._beta1
                     if not isinstance(self._beta1, Variable)
-                    else self._beta1.numpy().item(0)
+                    else self._beta1.item(0)
                 )
                 _beta2 = (
                     self._beta2
                     if not isinstance(self._beta2, Variable)
-                    else self._beta2.numpy().item(0)
+                    else self._beta2.item(0)
                 )

                 if framework.in_dygraph_mode():
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index 0233ad1c972c1..084450b8cd996 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -434,12 +434,12 @@ def _append_optimize_op(self, block, param_and_grad):
             _beta1 = (
                 self._beta1
                 if not isinstance(self._beta1, Variable)
-                else self._beta1.numpy().item(0)
+                else self._beta1.item(0)
             )
             _beta2 = (
                 self._beta2
                 if not isinstance(self._beta2, Variable)
-                else self._beta2.numpy().item(0)
+                else self._beta2.item(0)
             )

             _, _, _, _, _, _ = _C_ops.adamw_(
diff --git a/python/paddle/quantization/imperative/ptq_quantizer.py b/python/paddle/quantization/imperative/ptq_quantizer.py
index 41b0be44a7530..97d61ebbf0f34 100644
--- a/python/paddle/quantization/imperative/ptq_quantizer.py
+++ b/python/paddle/quantization/imperative/ptq_quantizer.py
@@ -24,7 +24,7 @@


 def abs_max_value(tensor):
-    return float(paddle.max(paddle.abs(tensor)).numpy())
+    return float(paddle.max(paddle.abs(tensor)))


 def merge_max_value(old, new):
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index 002684ceb349c..479f9123008a2 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -466,7 +466,7 @@ def body(i, ten):
     )

     if _non_static_mode():
-        now_cond = pre_cond.numpy().item()
+        now_cond = pre_cond.item()
         while now_cond:
             output_vars = body(*loop_vars)
             if not isinstance(output_vars, (list, tuple)):
@@ -476,7 +476,7 @@ def body(i, ten):
                     "body in while_loop should return the same arity "
                     "(length and structure) and types as loop_vars"
                 )
-            now_cond = cond(*output_vars).numpy().item()
+            now_cond = cond(*output_vars).item()
             map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
         return loop_vars

@@ -968,7 +968,7 @@ def false_func():
     if _non_static_mode():
         assert isinstance(pred, Variable), "The pred in cond must be Variable"
         assert pred.size == 1, "condition input's numel should be 1"
-        pred = pred.numpy().item()
+        pred = pred.item()
         if pred:
             if true_fn is not None:
                 if not callable(true_fn):
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 848d73d80f7ea..ff0747a9388c4 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -329,9 +329,7 @@ def slice(input, axes, starts, ends):

         if isinstance(starts, (list, tuple)):
             starts = [
-                item.numpy().item(0)
-                if isinstance(item, tmp_tensor_type)
-                else item
+                item.item(0) if isinstance(item, tmp_tensor_type) else item
                 for item in starts
             ]
         elif isinstance(starts, tmp_tensor_type):
@@ -341,9 +339,7 @@ def slice(input, axes, starts, ends):

         if isinstance(ends, (list, tuple)):
             ends = [
-                item.numpy().item(0)
-                if isinstance(item, tmp_tensor_type)
-                else item
+                item.item(0) if isinstance(item, tmp_tensor_type) else item
                 for item in ends
             ]
         elif isinstance(ends, tmp_tensor_type):
@@ -1068,7 +1064,8 @@ def tolist(x):
             print(expectlist)   #[0, 1, 2, 3, 4]

     """
-    return x.numpy().tolist()
+    # TODO(zhouwei): will remove 0D Tensor.numpy() hack
+    return x.numpy(False).tolist()


 def concat(x, axis=0, name=None):
@@ -1117,7 +1114,6 @@ def concat(x, axis=0, name=None):
     input = x
     if in_dygraph_mode():
         if isinstance(axis, Variable):
-            axis = axis.numpy()
             axis = axis.item(0)
         if not isinstance(input, Variable):
             input = [t for t in input if t.shape.count(0) == 0]
@@ -1952,7 +1948,6 @@ def split(x, num_or_sections, axis=0, name=None):
     dim = axis
     if in_dygraph_mode():
         if isinstance(dim, Variable):
-            dim = dim.numpy()
             dim = dim.item(0)
         assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
         dim = (len(input.shape) + dim) if dim < 0 else dim
@@ -1961,9 +1956,7 @@ def split(x, num_or_sections, axis=0, name=None):
             if paddle.utils._contain_var(num_or_sections):
                 for index, item in enumerate(num_or_sections):
                     if isinstance(item, Variable):
-                        num_or_sections[index] = num_or_sections[index].numpy()[
-                            0
-                        ]
+                        num_or_sections[index] = num_or_sections[index].item()
         elif not isinstance(num_or_sections, int):
             raise TypeError(
                 "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
@@ -2593,10 +2586,10 @@ def unsqueeze(x, axis, name=None):
         if isinstance(axes, int):
             axes = [axes]
         elif isinstance(axes, Variable):
-            axes = axes.numpy().tolist()
+            axes = axes.tolist()
         elif isinstance(axes, (list, tuple)):
             axes = [
-                item.numpy().item(0) if isinstance(item, Variable) else item
+                item.item(0) if isinstance(item, Variable) else item
                 for item in axes
             ]
         return _C_ops.unsqueeze(input, axes)
@@ -2659,10 +2652,10 @@ def unsqueeze_(x, axis, name=None):
         if isinstance(axes, int):
             axes = [axes]
         elif isinstance(axes, Variable):
-            axes = axes.numpy().tolist()
+            axes = axes.tolist()
         elif isinstance(axes, (list, tuple)):
             axes = [
-                item.numpy().item(0) if isinstance(item, Variable) else item
+                item.item(0) if isinstance(item, Variable) else item
                 for item in axes
             ]
         return _C_ops.unsqueeze_(input, axes)
@@ -3148,7 +3141,7 @@ def tile(x, repeat_times, name=None):
             assert (
                 repeat_times.ndim == 1
             ), "Only support ndim == 1 while repeat_times is a Tensor."
-            repeat_times = repeat_times.numpy().tolist()
+            repeat_times = repeat_times.tolist()

         return _C_ops.tile(x, repeat_times)
     else:
@@ -3648,9 +3641,7 @@ def reshape_(x, shape, name=None):
         tmp_tensor_type = core.eager.Tensor
         if isinstance(shape, (list, tuple)):
             shape = [
-                item.numpy().item(0)
-                if isinstance(item, tmp_tensor_type)
-                else item
+                item.item(0) if isinstance(item, tmp_tensor_type) else item
                 for item in shape
            ]
            if shape == x.shape:
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index e254884c3eacf..cf63a80c98334 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -2865,9 +2865,9 @@ def clip(x, min=None, max=None, name=None):

     if in_dygraph_mode():
         if isinstance(min, Variable):
-            min = min.numpy().item(0)
+            min = min.item(0)
         if isinstance(max, Variable):
-            max = max.numpy().item(0)
+            max = max.item(0)
         min = min_ if min is None else min
         max = max_ if max is None else max
         return _C_ops.clip(x, min, max)
@@ -2932,9 +2932,9 @@ def clip_(x, min=None, max=None, name=None):
     fmin = float(np.finfo(np.float32).min)
     fmax = float(np.finfo(np.float32).max)
     if isinstance(min, Variable):
-        min = min.numpy().item(0)
+        min = min.item(0)
     if isinstance(max, Variable):
-        max = max.numpy().item(0)
+        max = max.item(0)
     min = fmin if min is None else min
     max = fmax if max is None else max

diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py
index fda011b71779c..8a32cefa37a26 100644
--- a/python/paddle/tensor/to_string.py
+++ b/python/paddle/tensor/to_string.py
@@ -257,7 +257,7 @@ def to_string(var, prefix='Tensor'):
     if var.dtype == core.VarDesc.VarType.BF16:
         var = var.astype('float32')

-    np_var = var.numpy()
+    np_var = var.numpy(False)

     if len(var.shape) == 0:
         size = 0
@@ -291,7 +291,8 @@ def _format_dense_tensor(tensor, indent):
     if tensor.dtype == core.VarDesc.VarType.BF16:
         tensor = tensor.astype('float32')

-    np_tensor = tensor.numpy()
+    # TODO(zhouwei): will remove 0D Tensor.numpy() hack
+    np_tensor = tensor.numpy(False)

     if len(tensor.shape) == 0:
         size = 0
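
For reviewers, a minimal sketch of the behavior this patch targets (illustration only, not part of the diff; assumes a Paddle build that includes this change):

import paddle

x = paddle.full([], 0.5)  # 0D (scalar) Tensor

# Default Tensor.numpy() keeps the deprecated 0D-to-1D hack and logs a warning:
print(x.numpy().shape)  # (1,)

# Tensor.numpy(False) skips the hack and returns a true 0D numpy array:
print(x.numpy(False).shape)  # ()

# Preferred scalar extraction, which this patch switches internal call sites to:
print(float(x))  # 0.5
print(x.item())  # 0.5

Internally, callers that only need a Python scalar or list are moved from Tensor.numpy()[0] / .numpy().item(0) / .numpy().tolist() to .item() / .item(0) / .tolist(), so they no longer depend on the 1D hack; callers that genuinely need an ndarray pass numpy(False).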