|
| 1 | +from typing import Tuple, Union, cast |
| 2 | + |
1 | 3 | import torch |
2 | 4 | import torch.nn as nn |
3 | 5 | import torch.nn.functional as F |
@@ -25,10 +27,15 @@ def googlenet( |
25 | 27 | training. Default: 1008 when pretrained is True. |
26 | 28 | transform_input (bool): If True, preprocesses the input according to |
27 | 29 | the method with which it was trained on ImageNet. Default: *False* |
| 30 | + bgr_transform (bool): If True and transform_input is True, perform an |
| 31 | + RGB to BGR transform in the internal preprocessing. |
| 32 | + Default: *False* |
28 | 33 | """ |
29 | 34 | if pretrained: |
30 | 35 | if "transform_input" not in kwargs: |
31 | 36 | kwargs["transform_input"] = True |
| 37 | + if "bgr_transform" not in kwargs: |
| 38 | + kwargs["bgr_transform"] = False |
32 | 39 | if "aux_logits" not in kwargs: |
33 | 40 | kwargs["aux_logits"] = False |
34 | 41 | if "out_features" not in kwargs: |
@@ -56,10 +63,12 @@ def __init__( |
56 | 63 | out_features: int = 1008, |
57 | 64 | aux_logits: bool = False, |
58 | 65 | transform_input: bool = False, |
| 66 | + bgr_transform: bool = False, |
59 | 67 | ) -> None: |
60 | 68 | super(InceptionV1, self).__init__() |
61 | 69 | self.aux_logits = aux_logits |
62 | 70 | self.transform_input = transform_input |
| 71 | + self.bgr_transform = bgr_transform |
63 | 72 | lrn_vals = (9, 9.99999974738e-05, 0.5, 1.0) |
64 | 73 |
|
65 | 74 | self.conv1 = nn.Conv2d( |
@@ -125,10 +134,12 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor: |
125 | 134 | assert x.min() >= 0.0 and x.max() <= 1.0 |
126 | 135 | x = x.unsqueeze(0) if x.dim() == 3 else x |
127 | 136 | x = x * 255 - 117 |
128 | | - x = x.clone()[:, [2, 1, 0]] # RGB to BGR |
| 137 | + x = x[:, [2, 1, 0]] if self.bgr_transform else x |
129 | 138 | return x |
130 | 139 |
|
131 | | - def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 140 | + def forward( |
| 141 | + self, x: torch.Tensor |
| 142 | + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: |
132 | 143 | x = self._transform_input(x) |
133 | 144 | x = F.pad(x, (2, 3, 2, 3)) |
134 | 145 | x = self.conv1(x) |
@@ -173,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
173 | 184 | x = self.drop(x) |
174 | 185 | x = self.fc(x) |
175 | 186 | if not self.aux_logits: |
176 | | - return x |
| 187 | + return cast(torch.Tensor, x) |
177 | 188 | else: |
178 | 189 | return x, aux1_output, aux2_output |
179 | 190 |
|
|
0 commit comments