Description
🐛 Bug
There is a mismatch in semantics between torchscript and pytorch when indexing with a list of integers.
In PyTorch (which follows NumPy semantics), the list is converted into an int64 tensor and advanced indexing is performed.
In TorchScript, the list is treated as a tuple (in PyTorch semantics), and the indexing dispatches to several calls to select.
So x[[1, 2]] becomes x[1, 2] in TorchScript, which is not semantics-preserving.
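For reference, a minimal eager-mode sketch contrasting the two readings (the variable names here are illustrative only):

import torch

x = torch.zeros(4, 4)

# PyTorch/NumPy semantics: the list becomes an int64 index tensor, so
# x[[1, 2]] is advanced indexing and selects rows 1 and 2.
rows = x[torch.tensor([1, 2])]   # shape (2, 4)

# TorchScript's current reading treats the list as a tuple of indices,
# i.e. x[1, 2], which selects the single element at row 1, column 2.
elem = x[1, 2]                   # 0-dim tensor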
To Reproduce
In [4]: @torch.jit.script
   ...: def f(x):
   ...:     x[[1, 2]] = 1
   ...:     return x
   ...:
In [5]: x = torch.zeros(4, 4)
In [6]: f(x)
Out[6]:
tensor([[0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
In [7]: x = torch.zeros(4, 4)
In [8]: x[[1, 2]] = 1
In [9]: x
Out[9]:
tensor([[0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.]])
By inspecting f.code, we see that it is performing the equivalent of x[1][2]:
def forward(self,
    x: Tensor) -> Tensor:
  _0 = torch.select(torch.select(x, 0, 1), 0, 2)
  _1 = torch.copy_(_0, 1)
  return x
It should instead dispatch to index_fill_ or index_copy (or, more generally, index_put_).
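For reference, a rough eager-mode sketch of what the assignment could lower to; the literal values below are illustrative:

import torch

x = torch.zeros(4, 4)

# x[[1, 2]] = 1 expressed via index_fill_ along dim 0 ...
x.index_fill_(0, torch.tensor([1, 2]), 1.)

# ... or, more generally, via index_put_ with a tuple of index tensors,
# letting the value broadcast over the selected rows.
y = torch.zeros(4, 4)
y.index_put_((torch.tensor([1, 2]),), torch.tensor(1.))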
Note that the same (wrong) behavior happens for selection:
In [14]: @torch.jit.script
    ...: def g(x):
    ...:     return x[[1, 2]]
    ...:
In [15]: a = torch.arange(16).reshape(4,4)
In [16]: a
Out[16]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
In [17]: g(x)
Out[17]: tensor(1.)
In [18]: print(g.code)
def forward(self,
    x: Tensor) -> Tensor:
  _0 = torch.select(torch.select(x, 0, 1), 0, 2)
  return _0
Note that if we instead pass int64 tensors for the indices, everything works as expected.
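A minimal sketch of that workaround, assuming the indices are passed in as an int64 tensor (the function name is made up for illustration):

import torch

@torch.jit.script
def f_workaround(x, idx):
    # An int64 tensor index takes the advanced-indexing path in
    # TorchScript as well, matching eager semantics.
    x[idx] = 1
    return x

x = torch.zeros(4, 4)
f_workaround(x, torch.tensor([1, 2]))   # rows 1 and 2 become all ones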
I'm using PyTorch nightly 1.0.0.dev20190321.
Context
This is potentially related to #14332, but the situation presented there gives a compiler error, while the one here fails silently.
cc @suo