diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..2dcf2dfc36 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -123,7 +123,10 @@ def arange(start, stop, step): def pytorch_funcify_Join(op, **kwargs): def join(axis, *tensors): # tensors could also be tuples, and in this case they don't have a ndim - tensors = [torch.tensor(tensor) for tensor in tensors] + tensors = [ + torch.tensor(tensor) if not torch.is_tensor(tensor) else tensor + for tensor in tensors + ] return torch.cat(tensors, dim=axis)