Skip to content

Commit

Permalink
convert torch.split return to list in RAFT (#7597)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored May 19, 2023
1 parent 689ff29 commit d2f7486
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchvision/prototype/models/depth/stereo/raft_stereo.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,12 @@ def forward(
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
hidden_states.append(torch.tanh(hidden_state))
contexts.append(
torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1)
# mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
# `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
# JIT and thus we have to keep the wrong annotation here and silence mypy.
torch.split( # type: ignore[arg-type]
context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
)
)

_, Cf, Hf, Wf = fmap1.shape
Expand Down

0 comments on commit d2f7486

Please sign in to comment.