diff --git a/ivy/functional/backends/paddle/data_type.py b/ivy/functional/backends/paddle/data_type.py
index aa009c9f7533c..aa17a4c718178 100644
--- a/ivy/functional/backends/paddle/data_type.py
+++ b/ivy/functional/backends/paddle/data_type.py
@@ -118,13 +118,18 @@ def astype(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = ivy.as_native_dtype(dtype)
-
-    if copy and 0 in x.shape:
-        return paddle.empty(x.shape, dtype=dtype)
-
-    if x.dtype == dtype:
-        return x.clone() if copy else x
-    return x.clone().cast(dtype) if copy else x.cast(dtype)
+    if copy:
+        # clone is not supported for empty tensors,
+        # so build the copy with to_tensor instead
+        if 0 in x.shape:
+            return paddle.to_tensor(
+                x,
+                dtype=dtype,
+                place=x.place,
+                stop_gradient=x.stop_gradient,
+            )
+        return x.clone() if x.dtype == dtype else x.clone().cast(dtype)
+    return x if x.dtype == dtype else x.cast(dtype)
 
 
 def broadcast_arrays(*arrays: paddle.Tensor) -> List[paddle.Tensor]:
diff --git a/ivy/functional/backends/paddle/manipulation.py b/ivy/functional/backends/paddle/manipulation.py
index 900f684a6e85c..6ae7a7304f8e7 100644
--- a/ivy/functional/backends/paddle/manipulation.py
+++ b/ivy/functional/backends/paddle/manipulation.py
@@ -363,38 +363,43 @@ def repeat(
 def tile(
     x: paddle.Tensor, /, repeats: Sequence[int], *, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
-    if x.ndim >= 7:
-        repeats = (
-            repeats.numpy().tolist() if isinstance(repeats, paddle.Tensor) else repeats
-        )
-        new_shape = [*x.shape[:5], -1]
-        reshaped_tensor = paddle.reshape(x, new_shape)
-        new_repeats = repeats[:5] + [math.prod(repeats[5:])]
-        tiled_reshaped_tensor = tile(reshaped_tensor, new_repeats)
-        tiled_shape = tuple(s * r for s, r in zip(x.shape, repeats))
-        result = paddle.reshape(tiled_reshaped_tensor, tiled_shape)
-        return result
-    if ivy.min(repeats) == 0:
-        # This logic is to mimic other backends behaviour when a 0 in repeat
-        # is received since paddle doesn't natively support it
-        if len(repeats) < x.ndim:
+    repeats = repeats.tolist() if isinstance(repeats, paddle.Tensor) else list(repeats)
+    # Paddle doesn't natively support repeats containing zeros
+    if 0 in x.shape or (len(repeats) > 0 and min(repeats) == 0):
+        if x.ndim == 0:
+            shape = repeats
+        elif len(repeats) <= x.ndim:
             shape = x.shape
-            shape[-len(repeats) :] = paddle_backend.multiply(
-                shape[-len(repeats) :], repeats
-            ).tolist()
-        elif len(repeats) > x.ndim:
-            shape = (
-                repeats.tolist()
-                if isinstance(repeats, paddle.Tensor)
-                else list(repeats)
-            )
-            shape[-x.ndim - 1 :] = paddle_backend.multiply(
-                shape[-x.ndim - 1 :], repeats
-            ).tolist()
+            # use len(shape) - len(repeats) since -len(repeats) would be -0
+            shape[len(shape) - len(repeats) :] = [
+                s * r for s, r in zip(shape[len(shape) - len(repeats) :], repeats)
+            ]
         else:
-            shape = paddle_backend.multiply(x.shape, repeats).tolist()
-        return paddle.zeros(shape).cast(x.dtype)
-
+            shape = repeats.copy()
+            shape[-x.ndim :] = [s * r for r, s in zip(shape[-x.ndim :], x.shape)]
+        return paddle.empty(shape, dtype=x.dtype)
+    # Paddle's tile doesn't natively support more than 6 dimensions
+    if x.ndim > 6 or len(repeats) > 6:
+        if len(repeats) < x.ndim:
+            repeats = [1] * (x.ndim - len(repeats)) + repeats
+        elif len(repeats) > x.ndim:
+            shape = [1] * (len(repeats) - x.ndim) + x.shape
+            x = paddle.reshape(x, shape)
+        cur_shape = x.shape
+        cur_tensor = x
+        for i in range(0, x.ndim, 5):
+            size = 5 if i <= x.ndim - 5 else x.ndim - i
+            red_shape = [*cur_shape[:size], -1]
+            red_tensor = paddle.reshape(cur_tensor, red_shape)
+            red_repeats = [*repeats[i : i + size], 1]
+            tiled_red_tensor = paddle.tile(red_tensor, red_repeats)
+            perm = [size, *list(range(size))]
+            tiled_red_tensor = paddle.transpose(tiled_red_tensor, perm)
+            cur_shape = cur_shape[size:] + [
+                s * r for s, r in zip(cur_shape[:size], repeats[i : i + size])
+            ]
+            cur_tensor = paddle.reshape(tiled_red_tensor, cur_shape)
+        return cur_tensor
     return paddle.tile(x, repeats)
 
 
diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py
index 6b231f3161d26..225169f6561e5 100644
--- a/ivy/functional/backends/paddle/statistical.py
+++ b/ivy/functional/backends/paddle/statistical.py
@@ -160,6 +160,19 @@ def axis_condition(axis):
     return ret.astype(ret_dtype)
 
 
+def _calculate_reduced_shape(x, axis, keepdims):
+    if axis is None:
+        axis = tuple(range(len(x.shape)))
+    elif isinstance(axis, int):
+        axis = (axis % len(x.shape),)
+    else:
+        # normalize negative axes so the membership checks below work
+        axis = tuple(a % len(x.shape) for a in axis)
+    if keepdims:
+        return [1 if i in axis else x.shape[i] for i in range(len(x.shape))]
+    return [x.shape[i] for i in range(len(x.shape)) if i not in axis]
+
+
 @with_supported_dtypes(
     {"2.6.0 and below": ("bool", "complex", "float32", "float64")}, backend_version
 )
@@ -174,7 +187,10 @@ def mean(
 ) -> paddle.Tensor:
     if dtype is not None:
         x = ivy.astype(x, dtype).to_native()
-    if paddle.is_complex(x):
+    if 0 in x.shape:
+        shape = _calculate_reduced_shape(x, axis, keepdims)
+        ret = paddle.full(shape, float("nan"))
+    elif paddle.is_complex(x):
         ret = paddle.complex(
             paddle.mean(x.real(), axis=axis, keepdim=keepdims),
             paddle.mean(x.imag(), axis=axis, keepdim=keepdims),
diff --git a/ivy/functional/backends/tensorflow/manipulation.py b/ivy/functional/backends/tensorflow/manipulation.py
index ed13db5d98afa..44ea4856381d0 100644
--- a/ivy/functional/backends/tensorflow/manipulation.py
+++ b/ivy/functional/backends/tensorflow/manipulation.py
@@ -312,22 +312,19 @@ def tile(
     *,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
-    if x.shape == ():
-        x = tf.reshape(x, (-1,))
-    if isinstance(repeats, Number):
-        repeats = [repeats]
-    if isinstance(repeats, tf.Tensor) and repeats.shape == ():
-        repeats = tf.reshape(repeats, (-1,))
-    # code to unify behaviour with numpy and torch
-    if len(x.shape) < len(repeats):
-        while len(x.shape) != len(repeats):
-            x = tf.expand_dims(x, 0)
-    elif len(x.shape) > len(repeats):
-        repeats = list(repeats)
-        while len(x.shape) != len(repeats):
-            repeats = [1] + repeats
+    # Unify behaviour with numpy and torch
     # TODO remove the unifying behaviour code if tensorflow handles this
     # https://github.com/tensorflow/tensorflow/issues/58002
+    if len(repeats) < len(x.shape):
+        repeats = (
+            repeats.numpy().tolist()
+            if isinstance(repeats, (tf.Tensor, tf.Variable))
+            else list(repeats)
+        )
+        repeats = [1] * (len(x.shape) - len(repeats)) + repeats
+    elif len(repeats) > len(x.shape):
+        shape = [1] * (len(repeats) - len(x.shape)) + x.shape
+        x = tf.reshape(x, shape)
     return tf.tile(x, repeats)
 
 
diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py
index 8c4bdbb9eda22..bf84f6e76be28 100644
--- a/ivy/functional/backends/torch/manipulation.py
+++ b/ivy/functional/backends/torch/manipulation.py
@@ -266,7 +266,7 @@ def tile(
 ) -> torch.Tensor:
     if isinstance(repeats, torch.Tensor):
         repeats = repeats.detach().cpu().numpy().tolist()
-    return x.repeat(repeats)
+    return torch.tile(x, repeats)
 
 
 def constant_pad(
diff --git a/ivy/functional/ivy/creation.py b/ivy/functional/ivy/creation.py
index 5a837dcd93bcc..f0077024e3a2b 100644
--- a/ivy/functional/ivy/creation.py
+++ b/ivy/functional/ivy/creation.py
@@ -128,6 +128,9 @@ def _remove_np_bfloat16(obj):
     # from numpy arrays that have bfloat16 dtype using any extension because
     # bfloat16 in not supported natively by numpy (as of version <=1.25)
     if isinstance(obj, np.ndarray) and obj.dtype.name == "bfloat16":
+        # change dtype of empty array instead so that it doesn't lose its shape
+        if 0 in obj.shape:
+            return obj.astype(np.float32)
         return obj.tolist()
     return obj
 
diff --git a/ivy/functional/ivy/experimental/manipulation.py b/ivy/functional/ivy/experimental/manipulation.py
index a1467eddf7271..0aa7400b30856 100644
--- a/ivy/functional/ivy/experimental/manipulation.py
+++ b/ivy/functional/ivy/experimental/manipulation.py
@@ -2203,7 +2203,7 @@ def fill_diagonal(
     steps = ivy.arange(0, end, step)
     if isinstance(v, (ivy.Array, ivy.NativeArray)):
         v = ivy.reshape(v, (-1,)).astype(a.dtype)
-        v = ivy.tile(v, int(ivy.ceil(len(steps) / v.shape[0])))[: len(steps)]
+        v = ivy.tile(v, (int(ivy.ceil(len(steps) / v.shape[0])),))[: len(steps)]
     else:
         v = ivy.repeat(v, len(steps))
     ivy.scatter_flat(steps, v, size=a.shape[0], reduction="replace", out=a)
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
index 10a8b3385fbf3..2cf190bd2ac3d 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -335,28 +335,16 @@ def _masked_fill_helper(draw):
 @st.composite
 def _repeat_helper(draw):
     shape = draw(
-        helpers.get_shape(
-            min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10
-        )
+        st.shared(helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape")
     )
-
-    input_dtype, x = draw(
-        helpers.dtype_and_values(
-            available_dtypes=helpers.get_dtypes("valid"),
-            shape=shape,
-        )
-    )
-
-    MAX_NUMPY_DIMS = 32
     repeats = draw(
         st.lists(
-            st.integers(min_value=1, max_value=5),
+            st.integers(min_value=0, max_value=5),
             min_size=len(shape),
-            max_size=MAX_NUMPY_DIMS,
+            max_size=8,
         )
     )
-    assume(np.prod(repeats) * np.prod(shape) <= 2**28)
-    return input_dtype, x, repeats
+    return repeats
 
 
 @st.composite
@@ -11332,12 +11320,19 @@ def test_torch_remainder_(
     class_tree=CLASS_TREE,
     init_tree="torch.tensor",
     method_name="repeat",
-    dtype_x_repeats=_repeat_helper(),
-    unpack_repeat=st.booleans(),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        shape=st.shared(
+            helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape"
+        ),
+    ),
+    repeats=_repeat_helper(),
+    unpack_repeats=st.booleans(),
 )
 def test_torch_repeat(
-    dtype_x_repeats,
-    unpack_repeat,
+    dtype_and_x,
+    repeats,
+    unpack_repeats,
     frontend_method_data,
     init_flags,
     method_flags,
@@ -11345,20 +11340,12 @@ def test_torch_repeat(
     on_device,
     backend_fw,
 ):
-    input_dtype, x, repeats = dtype_x_repeats
-
-    if backend_fw == "paddle":
-        # paddle only supports size of the shape of repeats
-        # to be less than or equal to 6
-        assume(len(repeats) <= 6)
-
-    repeat = {
-        "repeats": repeats,
-    }
-    if unpack_repeat:
-        method_flags.num_positional_args = len(repeat["repeats"]) + 1
-        for i, x_ in enumerate(repeat["repeats"]):
-            repeat[f"x{i}"] = x_
+    input_dtype, x = dtype_and_x
+    if unpack_repeats and len(repeats) > 0:
+        method_flags.num_positional_args = len(repeats)
+        method_kwargs = {f"x{i}": x_ for i, x_ in enumerate(repeats)}
+    else:
+        method_kwargs = {"repeats": repeats}
     helpers.test_frontend_method(
         init_input_dtypes=input_dtype,
         backend_to_test=backend_fw,
@@ -11366,7 +11353,7 @@
"data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np=repeat, + method_all_as_kwargs_np=method_kwargs, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py index 6745c947a6d5e..3bdc8866c7790 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py @@ -701,29 +701,25 @@ def test_swapaxes( @handle_test( fn_tree="functional.ivy.tile", dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=True), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - repeat=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map( - lambda rep: (len(rep),) + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared( + helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape" ), - min_value=0, - max_value=10, ), + repeats=st.lists(st.integers(min_value=0, max_value=5), max_size=8), ) -def test_tile(*, dtype_value, repeat, test_flags, backend_fw, fn_name, on_device): +def test_tile(*, dtype_value, repeats, test_flags, backend_fw, fn_name, on_device): dtype, value = dtype_value - repeat_dtype, repeat_list = repeat + # Empty tensors do not copy correctly in paddle + assume(backend_fw != "paddle" or 0 not in value[0].shape) helpers.test_function( - input_dtypes=dtype + repeat_dtype, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, x=value[0], - repeats=repeat_list[0], + repeats=repeats, rtol_=1e-2, atol_=1e-2, xs_grad_idxs=[[0, 0]],