From c60b3a2098e1e5cabf8d720cfd5a87dcb7b59b2f Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Fri, 9 Feb 2024 03:36:01 +0000 Subject: [PATCH] fix: allow torch backend split to infer final split size (#28230) --- ivy/functional/backends/torch/manipulation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py index 3e18e83eb33c9..c124e3aab9a38 100644 --- a/ivy/functional/backends/torch/manipulation.py +++ b/ivy/functional/backends/torch/manipulation.py @@ -241,6 +241,10 @@ def split( torch.tensor(dim_size) / torch.tensor(num_or_size_splits) ) elif isinstance(num_or_size_splits, list): + if num_or_size_splits[-1] == -1: + # infer the final size split + remaining_size = dim_size - sum(num_or_size_splits[:-1]) + num_or_size_splits[-1] = remaining_size num_or_size_splits = tuple(num_or_size_splits) return list(torch.split(x, num_or_size_splits, axis))