[Zero-Dim] fix Tensor.numpy, control whether to hack process to 1D #51757

Merged · 1 commit · Mar 20, 2023
34 changes: 24 additions & 10 deletions paddle/fluid/pybind/eager_method.cc
@@ -123,15 +123,29 @@ static PyObject* tensor_method_numpy(TensorObject* self,
   size_t py_rank = tensor_dims.size();
   size_t numel = 1;
   if (py_rank == 0) {
-    // 0D Tensor hack process to 1D numpy, will remove in future
-    VLOG(0) << "Warning:: 0D Tensor cannot be used as Tensor.numpy()[0], Now "
-               "0D will be changed to 1D numpy to avoid this problem, but it's "
-               "not correct and will be removed in future. Please change "
-               "'Tensor.numpy()[0]' to 'float(Tensor)' or "
-               "'Tensor.numpy().item()' as soon as possible.";
-    py_rank = 1;
-    py_dims[0] = 1;
-    py_strides[0] = sizeof_dtype * numel;
+    Py_ssize_t args_num = PyTuple_Size(args);
+    bool set_to_1d = true;
+    if (args_num == (Py_ssize_t)1) {
+      PyObject* obj = PyTuple_GET_ITEM(args, 0);
+      if (obj == Py_False) {
+        set_to_1d = false;
+      }
+    }
+    if (set_to_1d) {
+      // 0D Tensor hack process to 1D numpy, will be removed in the future
+      VLOG(0)
+          << "Warning: 0D Tensor cannot be used as 'Tensor.numpy()[0]'. "
+             "In order to avoid this problem, the 0D Tensor is "
+             "temporarily converted to a 1D numpy array, but this "
+             "is not correct and the behavior will be removed in "
+             "the future. Please change 'Tensor.numpy()[0]' to "
+             "'float(Tensor)' as soon as possible; otherwise "
+             "'Tensor.numpy()[0]' will raise an "
+             "error.";
+      py_rank = 1;
+      py_dims[0] = 1;
+      py_strides[0] = sizeof_dtype * numel;
+    }
   } else {
     for (int i = tensor_dims.size() - 1; i >= 0; --i) {
       py_dims[i] = static_cast<size_t>(tensor_dims[i]);
@@ -143,7 +157,7 @@ static PyObject* tensor_method_numpy(TensorObject* self,
   PyObject* array = api.PyArray_NewFromDescr_(
       api.PyArray_Type_,
       api.PyArray_DescrFromType_(numpy_dtype),
-      tensor_dims.size(),
+      py_rank,
       py_dims,
       py_strides,
       nullptr,
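The new argument is read straight from the Python call: `Tensor.numpy()` keeps the backward-compatible 1D hack (and logs the warning above), while `Tensor.numpy(False)` returns a true 0D array. A minimal sketch of the resulting behavior, mirroring the unit test further down:

```python
import paddle

x = paddle.full([], 0.5)  # a 0D (zero-dimensional) tensor

a = x.numpy()       # default: still hacked to a 1D array for compatibility
print(a.shape)      # (1,)

b = x.numpy(False)  # opt out of the hack
print(b.shape)      # ()  -- a genuine 0D numpy array
print(b.item())     # 0.5
```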
@@ -415,9 +415,9 @@ def _broadcast_final_loss(self):
), "train_batch() in last stage should obtain vaild loss"
loss = self.total_loss.detach()
is_fp32 = (
paddle.to_tensor(1)
paddle.full([], 1, 'int64')
if loss.dtype == paddle.float32
else paddle.to_tensor(0)
else paddle.full([], 0, 'int64')
)
paddle.distributed.broadcast(
is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
@@ -426,7 +426,7 @@ def _broadcast_final_loss(self):
                 loss, src=self.global_rank, sync_op=True, group=self.pp_group
             )
         else:
-            is_fp32 = paddle.to_tensor(1)
+            is_fp32 = paddle.full([], 1, 'int64')
             paddle.distributed.broadcast(
                 is_fp32,
                 src=self._hcg.get_rank_from_stage(self.num_stages - 1),
@@ -435,7 +435,7 @@ def _broadcast_final_loss(self):
             )
             loss = (
                 paddle.zeros(shape=[1], dtype="float32")
-                if is_fp32.numpy()[0]
+                if is_fp32.item()
                 else paddle.zeros(shape=[1], dtype="float16")
             )
             paddle.distributed.broadcast(
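`paddle.full([], ...)` creates a genuine 0D scalar tensor, and `.item()` extracts its Python value without assuming a leading dimension. A small sketch of the distinction, assuming a build with 0D support:

```python
import paddle

flag = paddle.full([], 1, 'int64')  # 0D scalar tensor
print(flag.shape)   # []
print(flag.item())  # 1 -- works for 0D and shape-[1] tensors alike,
                    # unlike .numpy()[0], which needs at least one axis
```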
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -764,7 +764,7 @@ def __array__(self, dtype=None):
                 print(type(x_array)) #<class 'numpy.ndarray'>
                 print(x_array.shape) #(2, 2)
         """
-        array = self.numpy()
+        array = self.numpy(False)
         if dtype:
             array = array.astype(dtype)
         return array
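Calling `numpy(False)` here means `np.array(tensor)`, which dispatches to `__array__`, preserves a 0D tensor's true shape. For instance:

```python
import numpy as np
import paddle

x = paddle.full([], 2.0)
arr = np.array(x)  # goes through Tensor.__array__
print(arr.shape)   # () -- a genuine 0D ndarray, not (1,)
print(arr.item())  # 2.0
```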
2 changes: 1 addition & 1 deletion python/paddle/fluid/framework.py
@@ -773,7 +773,7 @@ def _var_base_to_np(var_base):
"paddle.fluid.framework._var_base_to_np is deprecated, please use var_base.numpy() instead of _var_base_to_np(var_base)."
)

return var_base.numpy()
return var_base.numpy(False)


def _cpu_num():
4 changes: 2 additions & 2 deletions python/paddle/fluid/layers/nn.py
@@ -698,10 +698,10 @@ def unsqueeze(input, axes, name=None):
     if isinstance(axes, int):
         axes = [axes]
     elif isinstance(axes, Variable):
-        axes = axes.numpy().tolist()
+        axes = axes.tolist()
     elif isinstance(axes, (list, tuple)):
         axes = [
-            item.numpy().item(0) if isinstance(item, Variable) else item
+            item.item(0) if isinstance(item, Variable) else item
             for item in axes
         ]
     return _C_ops.unsqueeze(input, axes)
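This is the recurring pattern of the PR: call `Tensor.tolist()` or `Tensor.item(0)` directly instead of taking a `.numpy()` detour that would trip over the 0D hack. Roughly:

```python
import paddle

axes = paddle.to_tensor([0, 2])
print(axes.tolist())  # [0, 2] -- no .numpy() round-trip needed

axis = paddle.full([], 1, 'int64')
print(axis.item(0))   # 1 -- also safe for a 0D tensor
```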
4 changes: 2 additions & 2 deletions python/paddle/fluid/optimizer.py
@@ -2432,12 +2432,12 @@ def _append_optimize_op(self, block, param_and_grad):
         _beta1 = (
             self._beta1
             if not isinstance(self._beta1, Variable)
-            else self._beta1.numpy().item(0)
+            else self._beta1.item(0)
         )
         _beta2 = (
             self._beta2
             if not isinstance(self._beta2, Variable)
-            else self._beta2.numpy().item(0)
+            else self._beta2.item(0)
         )
         master_weight = None
         _, _, _, _, _, _ = _legacy_C_ops.adam(
10 changes: 9 additions & 1 deletion python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -952,7 +952,15 @@ def test_tolist(self):

     def test_numpy(self):
         x = paddle.full([], 0.5)
-        np.testing.assert_array_equal(x.numpy(), np.array(0.5))
+        # 0D Tensor is hacked to a 1D numpy array by default; will be removed in the future
+        x_np = x.numpy()
+        np.testing.assert_array_equal(x_np.shape, (1,))
+        np.testing.assert_array_equal(x_np, np.array([0.5]))
+
+        # numpy(False) returns the original, correct 0D numpy array
+        x_np = x.numpy(False)
+        np.testing.assert_array_equal(x_np.shape, ())
+        np.testing.assert_array_equal(x_np, np.array(0.5))
 
     def test_numel(self):
         out = paddle.numel(self.x)
2 changes: 1 addition & 1 deletion python/paddle/geometric/message_passing/utils.py
@@ -29,7 +29,7 @@ def convert_out_size_to_list(out_size):
     elif isinstance(out_size, (int, np.int32, np.int64)):
         out_size = [out_size]
     else:
-        out_size = [out_size.numpy().astype(int)[0]]
+        out_size = [int(out_size)]
     return out_size


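`int(out_size)` works whether `out_size` is a 0D tensor or has shape `[1]`, whereas `.numpy().astype(int)[0]` assumes at least one dimension. Sketch:

```python
import paddle

out_size = paddle.full([], 10, 'int64')
print(int(out_size))  # 10 -- indexing a true 0D array with [0] would fail
```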
2 changes: 1 addition & 1 deletion python/paddle/incubate/operators/graph_send_recv.py
@@ -182,7 +182,7 @@ def convert_out_size_to_list(out_size):
     elif isinstance(out_size, (int, np.int32, np.int64)):
         out_size = [out_size]
     else:
-        out_size = [out_size.numpy().astype(int)[0]]
+        out_size = [int(out_size)]
     return out_size


4 changes: 2 additions & 2 deletions python/paddle/nn/functional/common.py
@@ -503,7 +503,7 @@ def _is_list_or_turple_(data):

     for i, dim in enumerate(out_shape):
         if isinstance(dim, Variable):
-            out_shape[i] = dim.numpy().item()
+            out_shape[i] = dim.item()
     if not (_is_list_or_turple_(out_shape)):
         raise TypeError("size should be a list or tuple or Variable.")
     # Validate the shape
@@ -1693,7 +1693,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):

     if in_dygraph_mode():
         if isinstance(pad, Variable):
-            pad = pad.numpy().tolist()
+            pad = pad.tolist()
         out = _C_ops.pad3d(x, pad, mode, value, data_format)
     else:
         attrs = {'mode': mode, 'value': value, 'data_format': data_format}
8 changes: 2 additions & 6 deletions python/paddle/nn/functional/vision.py
@@ -86,17 +86,13 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):

     if in_dygraph_mode():
         _out_shape = (
-            out_shape.numpy().tolist()
-            if isinstance(out_shape, Variable)
-            else out_shape
+            out_shape.tolist() if isinstance(out_shape, Variable) else out_shape
         )
         theta = theta._use_gpudnn(use_cudnn)
         return _C_ops.affine_grid(theta, _out_shape, align_corners)
     elif in_dynamic_mode():
         _out_shape = (
-            out_shape.numpy().tolist()
-            if isinstance(out_shape, Variable)
-            else out_shape
+            out_shape.tolist() if isinstance(out_shape, Variable) else out_shape
         )
         return _legacy_C_ops.affine_grid(
             theta,
2 changes: 1 addition & 1 deletion python/paddle/nn/initializer/assign.py
@@ -211,6 +211,6 @@ def __init__(self, value, name=None):

         # TODO: value is already a tensor; for efficiency it may not be necessary to convert it to numpy and then initialize from the numpy data.
         if isinstance(value, paddle.static.Variable):
-            value = value.numpy()
+            value = value.numpy(False)
 
         super().__init__(value)
8 changes: 4 additions & 4 deletions python/paddle/optimizer/adam.py
@@ -310,12 +310,12 @@ def _append_optimize_op(self, block, param_and_grad):
         _beta1 = (
             self._beta1
             if not isinstance(self._beta1, Variable)
-            else self._beta1.numpy().item(0)
+            else self._beta1.item(0)
         )
         _beta2 = (
             self._beta2
             if not isinstance(self._beta2, Variable)
-            else self._beta2.numpy().item(0)
+            else self._beta2.item(0)
         )
 
         _, _, _, _, _, _ = _C_ops.adam_(
@@ -623,12 +623,12 @@ def _append_optimize_multi_tensor_op(
         _beta1 = (
             self._beta1
             if not isinstance(self._beta1, Variable)
-            else self._beta1.numpy().item(0)
+            else self._beta1.item(0)
         )
         _beta2 = (
             self._beta2
             if not isinstance(self._beta2, Variable)
-            else self._beta2.numpy().item(0)
+            else self._beta2.item(0)
         )
 
         if framework.in_dygraph_mode():
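For reference, `beta1`/`beta2` may be passed to the optimizer as tensors, which is why the `isinstance(..., Variable)` branch exists. A sketch of that usage (assuming a build with 0D support):

```python
import paddle

linear = paddle.nn.Linear(2, 2)
# beta1 given as a tensor; _append_optimize_op reads it back via .item(0)
opt = paddle.optimizer.Adam(
    learning_rate=0.01,
    beta1=paddle.full([], 0.9),
    parameters=linear.parameters(),
)
```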
4 changes: 2 additions & 2 deletions python/paddle/optimizer/adamw.py
@@ -434,12 +434,12 @@ def _append_optimize_op(self, block, param_and_grad):
         _beta1 = (
             self._beta1
             if not isinstance(self._beta1, Variable)
-            else self._beta1.numpy().item(0)
+            else self._beta1.item(0)
         )
         _beta2 = (
             self._beta2
             if not isinstance(self._beta2, Variable)
-            else self._beta2.numpy().item(0)
+            else self._beta2.item(0)
         )
 
         _, _, _, _, _, _ = _C_ops.adamw_(
2 changes: 1 addition & 1 deletion python/paddle/quantization/imperative/ptq_quantizer.py
@@ -24,7 +24,7 @@


 def abs_max_value(tensor):
-    return float(paddle.max(paddle.abs(tensor)).numpy())
+    return float(paddle.max(paddle.abs(tensor)))
 
 
 def merge_max_value(old, new):
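Since reductions such as `paddle.max` return a scalar tensor, `float()` can be applied to it directly, dropping the `.numpy()` round-trip:

```python
import paddle

t = paddle.to_tensor([-3.0, 2.0])
print(float(paddle.max(paddle.abs(t))))  # 3.0
```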
6 changes: 3 additions & 3 deletions python/paddle/static/nn/control_flow.py
@@ -466,7 +466,7 @@ def body(i, ten):
     )
 
     if _non_static_mode():
-        now_cond = pre_cond.numpy().item()
+        now_cond = pre_cond.item()
         while now_cond:
             output_vars = body(*loop_vars)
             if not isinstance(output_vars, (list, tuple)):
@@ -476,7 +476,7 @@ def body(i, ten):
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars"
)
now_cond = cond(*output_vars).numpy().item()
now_cond = cond(*output_vars).item()
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
return loop_vars

@@ -968,7 +968,7 @@ def false_func():
     if _non_static_mode():
         assert isinstance(pred, Variable), "The pred in cond must be Variable"
         assert pred.size == 1, "condition input's numel should be 1"
-        pred = pred.numpy().item()
+        pred = pred.item()
         if pred:
             if true_fn is not None:
                 if not callable(true_fn):
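In dynamic mode the predicate is evaluated eagerly, and `.item()` extracts a Python scalar from the (possibly 0D) condition tensor. A small usage sketch, assuming eager execution:

```python
import paddle

i = paddle.full([], 0, 'int64')
ten = paddle.full([], 10, 'int64')
# cond(...) returns a scalar bool tensor; while_loop reads it with .item()
i, ten = paddle.static.nn.while_loop(
    cond=lambda i, ten: i < ten,
    body=lambda i, ten: [i + 1, ten],
    loop_vars=[i, ten],
)
print(i.item())  # 10
```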
31 changes: 11 additions & 20 deletions python/paddle/tensor/manipulation.py
@@ -329,9 +329,7 @@ def slice(input, axes, starts, ends):

     if isinstance(starts, (list, tuple)):
         starts = [
-            item.numpy().item(0)
-            if isinstance(item, tmp_tensor_type)
-            else item
+            item.item(0) if isinstance(item, tmp_tensor_type) else item
             for item in starts
         ]
     elif isinstance(starts, tmp_tensor_type):
@@ -341,9 +339,7 @@ def slice(input, axes, starts, ends):

     if isinstance(ends, (list, tuple)):
         ends = [
-            item.numpy().item(0)
-            if isinstance(item, tmp_tensor_type)
-            else item
+            item.item(0) if isinstance(item, tmp_tensor_type) else item
             for item in ends
         ]
     elif isinstance(ends, tmp_tensor_type):
@@ -1068,7 +1064,8 @@ def tolist(x):
             print(expectlist) #[0, 1, 2, 3, 4]
 
     """
-    return x.numpy().tolist()
+    # TODO(zhouwei): will remove 0D Tensor.numpy() hack
+    return x.numpy(False).tolist()
 
 
 def concat(x, axis=0, name=None):
@@ -1117,7 +1114,6 @@ def concat(x, axis=0, name=None):
     input = x
     if in_dygraph_mode():
         if isinstance(axis, Variable):
-            axis = axis.numpy()
             axis = axis.item(0)
         if not isinstance(input, Variable):
             input = [t for t in input if t.shape.count(0) == 0]
@@ -1952,7 +1948,6 @@ def split(x, num_or_sections, axis=0, name=None):
     dim = axis
     if in_dygraph_mode():
         if isinstance(dim, Variable):
-            dim = dim.numpy()
             dim = dim.item(0)
         assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
         dim = (len(input.shape) + dim) if dim < 0 else dim
@@ -1961,9 +1956,7 @@ def split(x, num_or_sections, axis=0, name=None):
         if paddle.utils._contain_var(num_or_sections):
             for index, item in enumerate(num_or_sections):
                 if isinstance(item, Variable):
-                    num_or_sections[index] = num_or_sections[index].numpy()[
-                        0
-                    ]
+                    num_or_sections[index] = num_or_sections[index].item()
         elif not isinstance(num_or_sections, int):
             raise TypeError(
                 "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
@@ -2593,10 +2586,10 @@ def unsqueeze(x, axis, name=None):
     if isinstance(axes, int):
         axes = [axes]
     elif isinstance(axes, Variable):
-        axes = axes.numpy().tolist()
+        axes = axes.tolist()
     elif isinstance(axes, (list, tuple)):
         axes = [
-            item.numpy().item(0) if isinstance(item, Variable) else item
+            item.item(0) if isinstance(item, Variable) else item
             for item in axes
         ]
     return _C_ops.unsqueeze(input, axes)
@@ -2659,10 +2652,10 @@ def unsqueeze_(x, axis, name=None):
     if isinstance(axes, int):
         axes = [axes]
     elif isinstance(axes, Variable):
-        axes = axes.numpy().tolist()
+        axes = axes.tolist()
     elif isinstance(axes, (list, tuple)):
         axes = [
-            item.numpy().item(0) if isinstance(item, Variable) else item
+            item.item(0) if isinstance(item, Variable) else item
             for item in axes
         ]
     return _C_ops.unsqueeze_(input, axes)
@@ -3148,7 +3141,7 @@ def tile(x, repeat_times, name=None):
         assert (
             repeat_times.ndim == 1
         ), "Only support ndim == 1 while repeat_times is a Tensor."
-        repeat_times = repeat_times.numpy().tolist()
+        repeat_times = repeat_times.tolist()
 
         return _C_ops.tile(x, repeat_times)
     else:
@@ -3648,9 +3641,7 @@ def reshape_(x, shape, name=None):
     tmp_tensor_type = core.eager.Tensor
     if isinstance(shape, (list, tuple)):
         shape = [
-            item.numpy().item(0)
-            if isinstance(item, tmp_tensor_type)
-            else item
+            item.item(0) if isinstance(item, tmp_tensor_type) else item
             for item in shape
         ]
         if shape == x.shape: