diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 52b85244b3d..3f3322b7a88 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: x = torch.transpose(x, 1, 2).contiguous() # flatten - x = x.view(batchsize, -1, height, width) + x = x.view(batchsize, num_channels, height, width) return x