diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py index 3e18e83eb33c9..c124e3aab9a38 100644 --- a/ivy/functional/backends/torch/manipulation.py +++ b/ivy/functional/backends/torch/manipulation.py @@ -241,6 +241,10 @@ def split( torch.tensor(dim_size) / torch.tensor(num_or_size_splits) ) elif isinstance(num_or_size_splits, list): + if num_or_size_splits[-1] == -1: + # infer the final size split + remaining_size = dim_size - sum(num_or_size_splits[:-1]) + num_or_size_splits[-1] = remaining_size num_or_size_splits = tuple(num_or_size_splits) return list(torch.split(x, num_or_size_splits, axis))