diff --git a/spec/draft/API_specification/manipulation_functions.rst b/spec/draft/API_specification/manipulation_functions.rst index 7eb7fa8b0..680efb55f 100644 --- a/spec/draft/API_specification/manipulation_functions.rst +++ b/spec/draft/API_specification/manipulation_functions.rst @@ -29,4 +29,5 @@ Objects in API roll squeeze stack + tile unstack diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 2bc929134..74538a8a3 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -10,6 +10,7 @@ "roll", "squeeze", "stack", + "tile", "unstack", ] @@ -257,6 +258,30 @@ def stack(arrays: Union[Tuple[array, ...], List[array]], /, *, axis: int = 0) -> """ +def tile(x: array, repetitions: Tuple[int, ...], /): + """ + Constructs an array by tiling an input array. + + Parameters + ---------- + x: array + input array. + repetitions: Tuple[int, ...] + number of repetitions along each axis (dimension). + + Let ``N = len(x.shape)`` and ``M = len(repetitions)``. + + If ``N > M``, the function must prepend ones until all axes (dimensions) are specified (e.g., if ``x`` has shape ``(8,6,4,2)`` and ``repetitions`` is the tuple ``(3,3)``, then ``repetitions`` must be treated as ``(1,1,3,3)``). + + If ``N < M``, the function must prepend singleton axes (dimensions) to ``x`` until ``x`` has as many axes (dimensions) as ``repetitions`` specifies (e.g., if ``x`` has shape ``(4,2)`` and ``repetitions`` is the tuple ``(3,3,3,3)``, then ``x`` must be treated as if it has shape ``(1,1,4,2)``). + + Returns + ------- + out: array + a tiled output array. The returned array must have the same data type as ``x`` and must have a rank (i.e., number of dimensions) equal to ``max(N, M)``. If ``S`` is the shape of the tiled array after prepending singleton dimensions (if necessary) and ``r`` is the tuple of repetitions after prepending ones (if necessary), then the number of elements along each axis (dimension) must satisfy ``S[i]*r[i]``, where ``i`` refers to the ``i`` th axis (dimension). + """ + + def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]: """ Splits an array in a sequence of arrays along the given axis.