From 82f869b4ebb18dd8cb8b19ee14dfb58e0f74192f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 26 Apr 2021 11:46:57 -0600 Subject: [PATCH 1/7] Add the pre-trained InceptionV1 Places365 model --- captum/optim/models/__init__.py | 9 + captum/optim/models/_common.py | 1 - captum/optim/models/_image/inception_v1.py | 40 +- .../models/_image/inception_v1_places365.py | 352 ++++++++++++++++ .../_image/inception_v1_places365_classes.py | 375 ++++++++++++++++++ tests/optim/models/test_models.py | 197 +++++++-- 6 files changed, 913 insertions(+), 61 deletions(-) create mode 100644 captum/optim/models/_image/inception_v1_places365.py create mode 100644 captum/optim/models/_image/inception_v1_places365_classes.py diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py index 635d1eb5b6..ee98f2ba8b 100755 --- a/captum/optim/models/__init__.py +++ b/captum/optim/models/__init__.py @@ -7,6 +7,13 @@ ) from ._image.inception5h_classes import INCEPTION5H_CLASSES # noqa: F401 from ._image.inception_v1 import InceptionV1, googlenet # noqa: F401 +from ._image.inception_v1_places365 import ( # noqa: F401 + InceptionV1Places365, + googlenet_places365, +) +from ._image.inception_v1_places365_classes import ( # noqa: F401 + INCEPTIONV1_PLACES365_CLASSES, +) __all__ = [ "RedirectedReluLayer", @@ -17,4 +24,6 @@ "InceptionV1", "googlenet", "INCEPTION5H_CLASSES", + "googlenet_places365", + "INCEPTIONV1_PLACES365_CLASSES", ] diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index d8976b4bc3..40ccd5d159 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -149,7 +149,6 @@ def __init__( out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 65e73e32d9..f21853c20b 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -7,7 +7,7 @@ GS_SAVED_WEIGHTS_URL = ( "https://github.com/pytorch/captum/raw/" - + "optim-wip/captum/optim/_models/inception5h.pth" + + "optim-wip/captum/optim/models/_image/inception5h.pth" ) @@ -69,6 +69,8 @@ def googlenet( # Better version of Inception V1 / GoogleNet for Inception5h class InceptionV1(nn.Module): + __constants__ = ["aux_logits", "transform_input", "bgr_transform"] + def __init__( self, out_features: int = 1008, @@ -78,7 +80,7 @@ def __init__( replace_relus_with_redirectedrelu: bool = False, use_linear_modules_only: bool = False, ) -> None: - super(InceptionV1, self).__init__() + super().__init__() self.aux_logits = aux_logits self.transform_input = transform_input self.bgr_transform = bgr_transform @@ -99,7 +101,6 @@ def __init__( out_channels=64, kernel_size=(7, 7), stride=(2, 2), - padding=3, groups=1, bias=True, ) @@ -121,7 +122,6 @@ def __init__( out_channels=192, kernel_size=(3, 3), stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -230,7 +230,7 @@ def __init__( activ: Type[nn.Module] = nn.ReLU, p_layer: Type[nn.Module] = nn.MaxPool2d, ) -> None: - super(InceptionModule, self).__init__() + super().__init__() self.conv_1x1 = nn.Conv2d( in_channels=in_channels, out_channels=c1x1, @@ -254,7 +254,6 @@ def __init__( out_channels=c3x3, kernel_size=(3, 3), stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -273,7 +272,6 @@ def __init__( out_channels=c5x5, kernel_size=(5, 5), 
stride=(1, 1), - padding=1, groups=1, bias=True, ) @@ -311,9 +309,9 @@ def __init__( out_features: int = 1008, activ: Type[nn.Module] = nn.ReLU, ) -> None: - super(AuxBranch, self).__init__() + super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d((4, 4)) - self.loss_conv = nn.Conv2d( + self.conv = nn.Conv2d( in_channels=in_channels, out_channels=128, kernel_size=(1, 1), @@ -321,21 +319,19 @@ def __init__( groups=1, bias=True, ) - self.loss_conv_relu = activ() - self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True) - self.loss_fc_relu = activ() - self.loss_dropout = nn.Dropout(0.699999988079071) - self.loss_classifier = nn.Linear( - in_features=1024, out_features=out_features, bias=True - ) + self.conv_relu = activ() + self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True) + self.fc1_relu = activ() + self.dropout = nn.Dropout(0.699999988079071) + self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.avg_pool(x) - x = self.loss_conv(x) - x = self.loss_conv_relu(x) + x = self.conv(x) + x = self.conv_relu(x) x = torch.flatten(x, 1) - x = self.loss_fc(x) - x = self.loss_fc_relu(x) - x = self.loss_dropout(x) - x = self.loss_classifier(x) + x = self.fc1(x) + x = self.fc1_relu(x) + x = self.dropout(x) + x = self.fc2(x) return x diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py new file mode 100644 index 0000000000..dc1f704048 --- /dev/null +++ b/captum/optim/models/_image/inception_v1_places365.py @@ -0,0 +1,352 @@ +from typing import Any, Optional, Tuple, Type, Union, cast + +import torch +import torch.nn as nn + +from captum.optim.models._common import Conv2dSame, RedirectedReluLayer, SkipLayer + +GS_SAVED_WEIGHTS_URL = ( + "https://pytorch-tutorial-assets.s3.amazonaws.com/" + + "captum/inceptionv1_places365.pth" +) + + +def googlenet_places365( + pretrained: bool = False, + progress: bool = True, + model_path: Optional[str] = None, + **kwargs: Any +) -> "InceptionV1Places365": + r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from + `"Going Deeper with Convolutions" `_. + + The pretrained GoogleNet model was trained using the MIT Places365 Standard + dataset. See here for more information: https://arxiv.org/abs/1610.02055 + + Args: + pretrained (bool, optional): If True, returns a model pre-trained on the MIT + Places365 Standard dataset. + progress (bool, optional): If True, displays a progress bar of the download to + stderr + model_path (str, optional): Optional path for InceptionV1 model file. + replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained + model with Redirected ReLU in place of ReLU layers. + use_linear_modules_only (bool, optional): If True, return pretrained + model with all nonlinear layers replaced with linear equivalents. + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. Default: *True* + out_features (int, optional): Number of output features in the model used for + training. Default: 365 when pretrained is True. + transform_input (bool, optional): If True, preprocesses the input according to + the method with which it was trained on Places365. 
Default: *True* + """ + + if pretrained: + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "replace_relus_with_redirectedrelu" not in kwargs: + kwargs["replace_relus_with_redirectedrelu"] = True + if "use_linear_modules_only" not in kwargs: + kwargs["use_linear_modules_only"] = False + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = True + if "out_features" not in kwargs: + kwargs["out_features"] = 365 + + model = InceptionV1Places365(**kwargs) + + if model_path is None: + state_dict = torch.hub.load_state_dict_from_url( + GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False + ) + else: + state_dict = torch.load(model_path, map_location="cpu") + model.load_state_dict(state_dict) + return model + + return InceptionV1Places365(**kwargs) + + +class InceptionV1Places365(nn.Module): + """ + MIT Places365 variant of the InceptionV1 model. + + Args: + out_features (int, optional): Number of output features in the model used for + training. Default: 365 when pretrained is True. + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. Default: *True* + transform_input (bool, optional): If True, preprocesses the input according to + the method with which it was trained on Places365. Default: *True* + replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained + model with Redirected ReLU in place of ReLU layers. + use_linear_modules_only (bool, optional): If True, return pretrained + model with all nonlinear layers replaced with linear equivalents. + """ + + __constants__ = ["aux_logits", "transform_input"] + + def __init__( + self, + out_features: int = 365, + aux_logits: bool = True, + transform_input: bool = True, + replace_relus_with_redirectedrelu: bool = False, + use_linear_modules_only: bool = False, + ) -> None: + super().__init__() + self.aux_logits = aux_logits + self.transform_input = transform_input + lrn_vals = (5, 9.999999747378752e-05, 0.75, 1.0) + + if use_linear_modules_only: + activ = SkipLayer + pool = nn.AvgPool2d + else: + if replace_relus_with_redirectedrelu: + activ = RedirectedReluLayer + else: + activ = nn.ReLU + pool = nn.MaxPool2d + + self.conv1 = Conv2dSame( + in_channels=3, + out_channels=64, + kernel_size=(7, 7), + stride=(2, 2), + groups=1, + bias=True, + ) + self.conv1_relu = activ() + self.pool1 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.local_response_norm1 = nn.LocalResponseNorm(*lrn_vals) + + self.conv2 = nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv2_relu = activ() + self.conv3 = Conv2dSame( + in_channels=64, + out_channels=192, + kernel_size=(3, 3), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv3_relu = activ() + self.local_response_norm2 = nn.LocalResponseNorm(*lrn_vals) + + self.pool2 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed3a = InceptionModule(192, 64, 96, 128, 16, 32, 32, activ, pool) + self.mixed3a_relu = activ() + self.mixed3b = InceptionModule(256, 128, 128, 192, 32, 96, 64, activ, pool) + self.mixed3b_relu = activ() + self.pool3 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed4a = InceptionModule(480, 192, 96, 208, 16, 48, 64, activ, pool) + self.mixed4a_relu = activ() + + if self.aux_logits: + self.aux1 = AuxBranch(512, out_features, activ) + + self.mixed4b = InceptionModule(512, 160, 112, 224, 24, 64, 64, activ, pool) + self.mixed4b_relu = activ() + self.mixed4c = InceptionModule(512, 
128, 128, 256, 24, 64, 64, activ, pool) + self.mixed4c_relu = activ() + self.mixed4d = InceptionModule(512, 112, 144, 288, 32, 64, 64, activ, pool) + self.mixed4d_relu = activ() + + if self.aux_logits: + self.aux2 = AuxBranch(528, out_features, activ) + + self.mixed4e = InceptionModule(528, 256, 160, 320, 32, 128, 128, activ, pool) + self.mixed4e_relu = activ() + self.pool4 = pool(kernel_size=3, stride=2, padding=0, ceil_mode=True) + self.mixed5a = InceptionModule(832, 256, 160, 320, 32, 128, 128, activ, pool) + self.mixed5a_relu = activ() + self.mixed5b = InceptionModule(832, 384, 192, 384, 48, 128, 128, activ, pool) + self.mixed5b_relu = activ() + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = nn.Dropout(0.4000000059604645) + self.fc = nn.Linear(1024, out_features) + + def _transform_input(self, x: torch.Tensor) -> torch.Tensor: + if self.transform_input: + assert x.dim() == 3 or x.dim() == 4 + assert x.min() >= 0.0 and x.max() <= 1.0 + x = x.unsqueeze(0) if x.dim() == 3 else x + x = x * 255 - torch.tensor( + [116.7894, 112.6004, 104.0437], device=x.device + ).view(3, 1, 1) + x = x[:, [2, 1, 0]] # RGB to BGR + return x + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + x = self._transform_input(x) + x = self.conv1(x) + x = self.conv1_relu(x) + x = self.pool1(x) + x = self.local_response_norm1(x) + + x = self.conv2(x) + x = self.conv2_relu(x) + x = self.conv3(x) + x = self.conv3_relu(x) + x = self.local_response_norm2(x) + + x = self.pool2(x) + x = self.mixed3a_relu(self.mixed3a(x)) + x = self.mixed3b_relu(self.mixed3b(x)) + x = self.pool3(x) + x = self.mixed4a_relu(self.mixed4a(x)) + + if self.aux_logits: + aux1_output = self.aux1(x) + + x = self.mixed4b_relu(self.mixed4b(x)) + x = self.mixed4c_relu(self.mixed4c(x)) + x = self.mixed4d_relu(self.mixed4d(x)) + + if self.aux_logits: + aux2_output = self.aux2(x) + + x = self.mixed4e_relu(self.mixed4e(x)) + x = self.pool4(x) + x = self.mixed5a_relu(self.mixed5a(x)) + x = self.mixed5b_relu(self.mixed5b(x)) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.drop(x) + x = self.fc(x) + if not self.aux_logits: + return cast(torch.Tensor, x) + else: + return x, aux1_output, aux2_output + + +class InceptionModule(nn.Module): + def __init__( + self, + in_channels: int, + c1x1: int, + c3x3reduce: int, + c3x3: int, + c5x5reduce: int, + c5x5: int, + pool_proj: int, + activ: Type[nn.Module] = nn.ReLU, + p_layer: Type[nn.Module] = nn.MaxPool2d, + ) -> None: + super().__init__() + self.conv_1x1 = nn.Conv2d( + in_channels=in_channels, + out_channels=c1x1, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.conv_3x3_reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=c3x3reduce, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_3x3_reduce_relu = activ() + self.conv_3x3 = Conv2dSame( + in_channels=c3x3reduce, + out_channels=c3x3, + kernel_size=(3, 3), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.conv_5x5_reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=c5x5reduce, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_5x5_reduce_relu = activ() + self.conv_5x5 = Conv2dSame( + in_channels=c5x5reduce, + out_channels=c5x5, + kernel_size=(5, 5), + stride=(1, 1), + groups=1, + bias=True, + ) + + self.pool = p_layer(kernel_size=3, stride=1, padding=1, ceil_mode=True) + self.pool_proj = nn.Conv2d( + in_channels=in_channels, + out_channels=pool_proj, + 
kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + c1x1 = self.conv_1x1(x) + + c3x3 = self.conv_3x3_reduce(x) + c3x3 = self.conv_3x3_reduce_relu(c3x3) + c3x3 = self.conv_3x3(c3x3) + + c5x5 = self.conv_5x5_reduce(x) + c5x5 = self.conv_5x5_reduce_relu(c5x5) + c5x5 = self.conv_5x5(c5x5) + + px = self.pool(x) + px = self.pool_proj(px) + return torch.cat([c1x1, c3x3, c5x5, px], dim=1) + + +class AuxBranch(nn.Module): + def __init__( + self, + in_channels: int = 512, + out_features: int = 365, + activ: Type[nn.Module] = nn.ReLU, + ) -> None: + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d((4, 4)) + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=128, + kernel_size=(1, 1), + stride=(1, 1), + groups=1, + bias=True, + ) + self.conv_relu = activ() + self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True) + self.fc1_relu = activ() + self.dropout = nn.Dropout(0.699999988079071) + self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.avg_pool(x) + x = self.conv(x) + x = self.conv_relu(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc1_relu(x) + x = self.dropout(x) + x = self.fc2(x) + return x diff --git a/captum/optim/models/_image/inception_v1_places365_classes.py b/captum/optim/models/_image/inception_v1_places365_classes.py new file mode 100644 index 0000000000..350f3b784d --- /dev/null +++ b/captum/optim/models/_image/inception_v1_places365_classes.py @@ -0,0 +1,375 @@ +from typing import List + +""" +List of classes for the MIT Places365 GoogleNet model trained using the Places365 +Standard dataset. Class list created from the Places365 GitHub repo class list: +https://github.com/CSAILVision/places365/blob/master/categories_places365.txt +""" + +INCEPTIONV1_PLACES365_CLASSES: List[str] = [ + "/a/airfield", + "/a/airplane_cabin", + "/a/airport_terminal", + "/a/alcove", + "/a/alley", + "/a/amphitheater", + "/a/amusement_arcade", + "/a/amusement_park", + "/a/apartment_building/outdoor", + "/a/aquarium", + "/a/aqueduct", + "/a/arcade", + "/a/arch", + "/a/archaelogical_excavation", + "/a/archive", + "/a/arena/hockey", + "/a/arena/performance", + "/a/arena/rodeo", + "/a/army_base", + "/a/art_gallery", + "/a/art_school", + "/a/art_studio", + "/a/artists_loft", + "/a/assembly_line", + "/a/athletic_field/outdoor", + "/a/atrium/public", + "/a/attic", + "/a/auditorium", + "/a/auto_factory", + "/a/auto_showroom", + "/b/badlands", + "/b/bakery/shop", + "/b/balcony/exterior", + "/b/balcony/interior", + "/b/ball_pit", + "/b/ballroom", + "/b/bamboo_forest", + "/b/bank_vault", + "/b/banquet_hall", + "/b/bar", + "/b/barn", + "/b/barndoor", + "/b/baseball_field", + "/b/basement", + "/b/basketball_court/indoor", + "/b/bathroom", + "/b/bazaar/indoor", + "/b/bazaar/outdoor", + "/b/beach", + "/b/beach_house", + "/b/beauty_salon", + "/b/bedchamber", + "/b/bedroom", + "/b/beer_garden", + "/b/beer_hall", + "/b/berth", + "/b/biology_laboratory", + "/b/boardwalk", + "/b/boat_deck", + "/b/boathouse", + "/b/bookstore", + "/b/booth/indoor", + "/b/botanical_garden", + "/b/bow_window/indoor", + "/b/bowling_alley", + "/b/boxing_ring", + "/b/bridge", + "/b/building_facade", + "/b/bullring", + "/b/burial_chamber", + "/b/bus_interior", + "/b/bus_station/indoor", + "/b/butchers_shop", + "/b/butte", + "/c/cabin/outdoor", + "/c/cafeteria", + "/c/campsite", + "/c/campus", + "/c/canal/natural", + "/c/canal/urban", + 
"/c/candy_store", + "/c/canyon", + "/c/car_interior", + "/c/carrousel", + "/c/castle", + "/c/catacomb", + "/c/cemetery", + "/c/chalet", + "/c/chemistry_lab", + "/c/childs_room", + "/c/church/indoor", + "/c/church/outdoor", + "/c/classroom", + "/c/clean_room", + "/c/cliff", + "/c/closet", + "/c/clothing_store", + "/c/coast", + "/c/cockpit", + "/c/coffee_shop", + "/c/computer_room", + "/c/conference_center", + "/c/conference_room", + "/c/construction_site", + "/c/corn_field", + "/c/corral", + "/c/corridor", + "/c/cottage", + "/c/courthouse", + "/c/courtyard", + "/c/creek", + "/c/crevasse", + "/c/crosswalk", + "/d/dam", + "/d/delicatessen", + "/d/department_store", + "/d/desert/sand", + "/d/desert/vegetation", + "/d/desert_road", + "/d/diner/outdoor", + "/d/dining_hall", + "/d/dining_room", + "/d/discotheque", + "/d/doorway/outdoor", + "/d/dorm_room", + "/d/downtown", + "/d/dressing_room", + "/d/driveway", + "/d/drugstore", + "/e/elevator/door", + "/e/elevator_lobby", + "/e/elevator_shaft", + "/e/embassy", + "/e/engine_room", + "/e/entrance_hall", + "/e/escalator/indoor", + "/e/excavation", + "/f/fabric_store", + "/f/farm", + "/f/fastfood_restaurant", + "/f/field/cultivated", + "/f/field/wild", + "/f/field_road", + "/f/fire_escape", + "/f/fire_station", + "/f/fishpond", + "/f/flea_market/indoor", + "/f/florist_shop/indoor", + "/f/food_court", + "/f/football_field", + "/f/forest/broadleaf", + "/f/forest_path", + "/f/forest_road", + "/f/formal_garden", + "/f/fountain", + "/g/galley", + "/g/garage/indoor", + "/g/garage/outdoor", + "/g/gas_station", + "/g/gazebo/exterior", + "/g/general_store/indoor", + "/g/general_store/outdoor", + "/g/gift_shop", + "/g/glacier", + "/g/golf_course", + "/g/greenhouse/indoor", + "/g/greenhouse/outdoor", + "/g/grotto", + "/g/gymnasium/indoor", + "/h/hangar/indoor", + "/h/hangar/outdoor", + "/h/harbor", + "/h/hardware_store", + "/h/hayfield", + "/h/heliport", + "/h/highway", + "/h/home_office", + "/h/home_theater", + "/h/hospital", + "/h/hospital_room", + "/h/hot_spring", + "/h/hotel/outdoor", + "/h/hotel_room", + "/h/house", + "/h/hunting_lodge/outdoor", + "/i/ice_cream_parlor", + "/i/ice_floe", + "/i/ice_shelf", + "/i/ice_skating_rink/indoor", + "/i/ice_skating_rink/outdoor", + "/i/iceberg", + "/i/igloo", + "/i/industrial_area", + "/i/inn/outdoor", + "/i/islet", + "/j/jacuzzi/indoor", + "/j/jail_cell", + "/j/japanese_garden", + "/j/jewelry_shop", + "/j/junkyard", + "/k/kasbah", + "/k/kennel/outdoor", + "/k/kindergarden_classroom", + "/k/kitchen", + "/l/lagoon", + "/l/lake/natural", + "/l/landfill", + "/l/landing_deck", + "/l/laundromat", + "/l/lawn", + "/l/lecture_room", + "/l/legislative_chamber", + "/l/library/indoor", + "/l/library/outdoor", + "/l/lighthouse", + "/l/living_room", + "/l/loading_dock", + "/l/lobby", + "/l/lock_chamber", + "/l/locker_room", + "/m/mansion", + "/m/manufactured_home", + "/m/market/indoor", + "/m/market/outdoor", + "/m/marsh", + "/m/martial_arts_gym", + "/m/mausoleum", + "/m/medina", + "/m/mezzanine", + "/m/moat/water", + "/m/mosque/outdoor", + "/m/motel", + "/m/mountain", + "/m/mountain_path", + "/m/mountain_snowy", + "/m/movie_theater/indoor", + "/m/museum/indoor", + "/m/museum/outdoor", + "/m/music_studio", + "/n/natural_history_museum", + "/n/nursery", + "/n/nursing_home", + "/o/oast_house", + "/o/ocean", + "/o/office", + "/o/office_building", + "/o/office_cubicles", + "/o/oilrig", + "/o/operating_room", + "/o/orchard", + "/o/orchestra_pit", + "/p/pagoda", + "/p/palace", + "/p/pantry", + "/p/park", + "/p/parking_garage/indoor", + 
"/p/parking_garage/outdoor", + "/p/parking_lot", + "/p/pasture", + "/p/patio", + "/p/pavilion", + "/p/pet_shop", + "/p/pharmacy", + "/p/phone_booth", + "/p/physics_laboratory", + "/p/picnic_area", + "/p/pier", + "/p/pizzeria", + "/p/playground", + "/p/playroom", + "/p/plaza", + "/p/pond", + "/p/porch", + "/p/promenade", + "/p/pub/indoor", + "/r/racecourse", + "/r/raceway", + "/r/raft", + "/r/railroad_track", + "/r/rainforest", + "/r/reception", + "/r/recreation_room", + "/r/repair_shop", + "/r/residential_neighborhood", + "/r/restaurant", + "/r/restaurant_kitchen", + "/r/restaurant_patio", + "/r/rice_paddy", + "/r/river", + "/r/rock_arch", + "/r/roof_garden", + "/r/rope_bridge", + "/r/ruin", + "/r/runway", + "/s/sandbox", + "/s/sauna", + "/s/schoolhouse", + "/s/science_museum", + "/s/server_room", + "/s/shed", + "/s/shoe_shop", + "/s/shopfront", + "/s/shopping_mall/indoor", + "/s/shower", + "/s/ski_resort", + "/s/ski_slope", + "/s/sky", + "/s/skyscraper", + "/s/slum", + "/s/snowfield", + "/s/soccer_field", + "/s/stable", + "/s/stadium/baseball", + "/s/stadium/football", + "/s/stadium/soccer", + "/s/stage/indoor", + "/s/stage/outdoor", + "/s/staircase", + "/s/storage_room", + "/s/street", + "/s/subway_station/platform", + "/s/supermarket", + "/s/sushi_bar", + "/s/swamp", + "/s/swimming_hole", + "/s/swimming_pool/indoor", + "/s/swimming_pool/outdoor", + "/s/synagogue/outdoor", + "/t/television_room", + "/t/television_studio", + "/t/temple/asia", + "/t/throne_room", + "/t/ticket_booth", + "/t/topiary_garden", + "/t/tower", + "/t/toyshop", + "/t/train_interior", + "/t/train_station/platform", + "/t/tree_farm", + "/t/tree_house", + "/t/trench", + "/t/tundra", + "/u/underwater/ocean_deep", + "/u/utility_room", + "/v/valley", + "/v/vegetable_garden", + "/v/veterinarians_office", + "/v/viaduct", + "/v/village", + "/v/vineyard", + "/v/volcano", + "/v/volleyball_court/outdoor", + "/w/waiting_room", + "/w/water_park", + "/w/water_tower", + "/w/waterfall", + "/w/watering_hole", + "/w/wave", + "/w/wet_bar", + "/w/wheat_field", + "/w/wind_farm", + "/w/windmill", + "/y/yard", + "/y/youth_hostel", + "/z/zen_garden", +] diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index 509009141a..56954f87e8 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -4,7 +4,7 @@ import torch -from captum.optim.models import googlenet +from captum.optim.models import googlenet, googlenet_places365 from captum.optim.models._common import RedirectedReluLayer, SkipLayer from tests.helpers.basic import BaseTest, assertTensorAlmostEqual @@ -39,8 +39,8 @@ class TestInceptionV1(BaseTest): def test_load_inceptionv1_with_redirected_relu(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 test due to insufficient Torch" + + " version." ) model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=True) _check_layer_in_model(self, model, RedirectedReluLayer) @@ -48,8 +48,8 @@ def test_load_inceptionv1_with_redirected_relu(self) -> None: def test_load_inceptionv1_no_redirected_relu(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception RedirectedRelu" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 RedirectedRelu test due to" + + " insufficient Torch version." 
) model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False) _check_layer_not_in_model(self, model, RedirectedReluLayer) @@ -58,8 +58,8 @@ def test_load_inceptionv1_no_redirected_relu(self) -> None: def test_load_inceptionv1_linear(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping load pretrained inception linear" - + " due to insufficient Torch version." + "Skipping load pretrained InceptionV1 linear test due to insufficient" + + " Torch version." ) model = googlenet(pretrained=True, use_linear_modules_only=True) _check_layer_not_in_model(self, model, RedirectedReluLayer) @@ -68,11 +68,11 @@ def test_load_inceptionv1_linear(self) -> None: _check_layer_in_model(self, model, SkipLayer) _check_layer_in_model(self, model, torch.nn.AvgPool2d) - def test_transform_inceptionv1(self) -> None: + def test_inceptionv1_transform(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping inceptionV1 internal transform" - + " due to insufficient Torch version." + "Skipping InceptionV1 internal transform test due to insufficient" + + " Torch version." ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) model = googlenet(pretrained=True) @@ -80,11 +80,11 @@ def test_transform_inceptionv1(self) -> None: expected_output = x * 255 - 117 assertTensorAlmostEqual(self, output, expected_output, 0) - def test_transform_bgr_inceptionv1(self) -> None: + def test_inceptionv1_transform_bgr(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping inceptionV1 internal transform" - + " BGR due to insufficient Torch version." + "Skipping InceptionV1 internal transform BGR test due to insufficient" + + " Torch version." ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) model = googlenet(pretrained=True, bgr_transform=True) @@ -92,45 +92,166 @@ def test_transform_bgr_inceptionv1(self) -> None: expected_output = x[:, [2, 1, 0]] * 255 - 117 assertTensorAlmostEqual(self, output, expected_output, 0) - def test_load_and_forward_basic_inceptionv1(self) -> None: + def test_inceptionv1_forward(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping basic pretrained inceptionV1 forward" - + " due to insufficient Torch version." + "Skipping pretrained InceptionV1 forward test due to insufficient" + + " Torch version." ) - x = torch.randn(1, 3, 224, 224).clamp(0, 1) + x = torch.zeros(1, 3, 224, 224) model = googlenet(pretrained=True) - try: - model(x) - test = True - except Exception: - test = False - self.assertTrue(test) + outputs = model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) - def test_load_and_forward_diff_sizes_inceptionv1(self) -> None: + def test_inceptionv1_load_and_forward_diff_sizes(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping pretrained inceptionV1 forward with different sized inputs" + "Skipping pretrained InceptionV1 forward with different sized inputs" + " due to insufficient Torch version." 
) - x = torch.randn(1, 3, 512, 512).clamp(0, 1) - x2 = torch.randn(1, 3, 383, 511).clamp(0, 1) + x = torch.zeros(1, 3, 512, 512) + x2 = torch.zeros(1, 3, 383, 511) model = googlenet(pretrained=True) - try: - model(x) - model(x2) - test = True - except Exception: - test = False - self.assertTrue(test) + outputs = model(x) + outputs2 = model(x2) + self.assertEqual(list(outputs.shape), [1, 1008]) + self.assertEqual(list(outputs2.shape), [1, 1008]) + + def test_inceptionv1_forward_aux(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 with aux logits forward due to" + + " insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=False, aux_logits=True) + outputs = model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 1008]] * 3) - def test_forward_aux_inceptionv1(self) -> None: + def test_inceptionv1_forward_cuda(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( - "Skipping pretrained inceptionV1 with aux logits forward" + "Skipping pretrained InceptionV1 forward CUDA test due to insufficient" + + " Torch version." + ) + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 forward CUDA test due to not" + + " supporting CUDA." + ) + x = torch.zeros(1, 3, 224, 224).cuda() + model = googlenet(pretrained=True).cuda() + outputs = model(x) + self.assertTrue(outputs.is_cuda) + self.assertEqual(list(outputs.shape), [1, 1008]) + + +class TestInceptionV1Places365(BaseTest): + def test_load_inceptionv1_places365_with_redirected_relu(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 due to insufficient" + + " Torch version." + ) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=True + ) + _check_layer_in_model(self, model, RedirectedReluLayer) + + def test_load_inceptionv1_places365_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 RedirectedRelu test" + " due to insufficient Torch version." ) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=False + ) + _check_layer_not_in_model(self, model, RedirectedReluLayer) + _check_layer_in_model(self, model, torch.nn.ReLU) + + def test_load_inceptionv1_places365_linear(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping load pretrained InceptionV1 Places365 linear test due to" + + " insufficient Torch version." + ) + model = googlenet_places365(pretrained=True, use_linear_modules_only=True) + _check_layer_not_in_model(self, model, RedirectedReluLayer) + _check_layer_not_in_model(self, model, torch.nn.ReLU) + _check_layer_not_in_model(self, model, torch.nn.MaxPool2d) + _check_layer_in_model(self, model, SkipLayer) + _check_layer_in_model(self, model, torch.nn.AvgPool2d) + + def test_inceptionv1_places365_transform(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping InceptionV1 Places365 internal transform test due to" + + " insufficient Torch version." 
+ ) x = torch.randn(1, 3, 224, 224).clamp(0, 1) - model = googlenet(pretrained=False, aux_logits=True) + model = googlenet_places365(pretrained=True) + output = model._transform_input(x) + expected_output = x * 255 - torch.tensor( + [116.7894, 112.6004, 104.0437], device=x.device + ).view(3, 1, 1) + expected_output = expected_output[:, [2, 1, 0]] + assertTensorAlmostEqual(self, output, expected_output, 0) + + def test_inceptionv1_places365_load_and_forward(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping basic pretrained InceptionV1 Places365 forward test due to" + + " insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=True) + outputs = model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_load_and_forward_diff_sizes(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward with different" + + " sized inputs test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 512, 512) + x2 = torch.zeros(1, 3, 383, 511) + model = googlenet_places365(pretrained=True) + + outputs = model(x) + outputs2 = model(x2) + + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + self.assertEqual([list(o.shape) for o in outputs2], [[1, 365]] * 3) + + def test_inceptionv1_places365_forward_no_aux(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 with aux logits forward" + + " test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=False, aux_logits=False) outputs = model(x) - self.assertEqual(len(outputs), 3) + self.assertEqual(list(outputs.shape), [1, 365]) + + def test_inceptionv1_places365_forward_cuda(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward CUDA test due to" + + " insufficient Torch version." + ) + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 forward CUDA test due to" + + " not supporting CUDA." + ) + x = torch.zeros(1, 3, 224, 224).cuda() + model = googlenet_places365(pretrained=True).cuda() + outputs = model(x) + + self.assertTrue(outputs[0].is_cuda) + self.assertTrue(outputs[1].is_cuda) + self.assertTrue(outputs[2].is_cuda) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) From e578a335e74e8f362a1403e382d71c6f37312f27 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Tue, 4 May 2021 13:28:43 -0600 Subject: [PATCH 2/7] Change minimum torch version for Places365 model * The change in serialization format in torch 1.6 is backwards compatible, but not forward compatible and thus I'm raising the minimum torch version for tests involving the model. 
* Related issue where the new serialization format in 1.6 caused the error: https://github.com/pytorch/pytorch/issues/42239 --- tests/optim/models/test_models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index 56954f87e8..d656894bd8 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -148,7 +148,7 @@ def test_inceptionv1_forward_cuda(self) -> None: class TestInceptionV1Places365(BaseTest): def test_load_inceptionv1_places365_with_redirected_relu(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping load pretrained InceptionV1 Places365 due to insufficient" + " Torch version." @@ -159,7 +159,7 @@ def test_load_inceptionv1_places365_with_redirected_relu(self) -> None: _check_layer_in_model(self, model, RedirectedReluLayer) def test_load_inceptionv1_places365_no_redirected_relu(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping load pretrained InceptionV1 Places365 RedirectedRelu test" + " due to insufficient Torch version." @@ -171,7 +171,7 @@ def test_load_inceptionv1_places365_no_redirected_relu(self) -> None: _check_layer_in_model(self, model, torch.nn.ReLU) def test_load_inceptionv1_places365_linear(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping load pretrained InceptionV1 Places365 linear test due to" + " insufficient Torch version." @@ -184,7 +184,7 @@ def test_load_inceptionv1_places365_linear(self) -> None: _check_layer_in_model(self, model, torch.nn.AvgPool2d) def test_inceptionv1_places365_transform(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping InceptionV1 Places365 internal transform test due to" + " insufficient Torch version." @@ -199,7 +199,7 @@ def test_inceptionv1_places365_transform(self) -> None: assertTensorAlmostEqual(self, output, expected_output, 0) def test_inceptionv1_places365_load_and_forward(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping basic pretrained InceptionV1 Places365 forward test due to" + " insufficient Torch version." @@ -210,7 +210,7 @@ def test_inceptionv1_places365_load_and_forward(self) -> None: self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) def test_inceptionv1_places365_load_and_forward_diff_sizes(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping pretrained InceptionV1 Places365 forward with different" + " sized inputs test due to insufficient Torch version." @@ -226,7 +226,7 @@ def test_inceptionv1_places365_load_and_forward_diff_sizes(self) -> None: self.assertEqual([list(o.shape) for o in outputs2], [[1, 365]] * 3) def test_inceptionv1_places365_forward_no_aux(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping pretrained InceptionV1 Places365 with aux logits forward" + " test due to insufficient Torch version." 
@@ -237,7 +237,7 @@ def test_inceptionv1_places365_forward_no_aux(self) -> None: self.assertEqual(list(outputs.shape), [1, 365]) def test_inceptionv1_places365_forward_cuda(self) -> None: - if torch.__version__ <= "1.2.0": + if torch.__version__ <= "1.6.0": raise unittest.SkipTest( "Skipping pretrained InceptionV1 Places365 forward CUDA test due to" + " insufficient Torch version." From 94626c2ea0e5dc34765e8a6c2c11c06b03844beb Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 16 May 2021 09:05:33 -0600 Subject: [PATCH 3/7] Replace model input range assertions with UserWarning Sometimes input tensors can have values higher than 1 or lower than 0, like for example when using some of the features from Captum's attr module. Rather than disabling these checks, I've changed them into UserWarnings instead. --- captum/optim/models/_image/inception_v1.py | 4 ++- .../models/_image/inception_v1_places365.py | 4 ++- tests/optim/models/test_models.py | 26 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index f21853c20b..83ce7b57b0 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple, Type, Union, cast +from warnings import warn import torch import torch.nn as nn @@ -165,7 +166,8 @@ def __init__( def _transform_input(self, x: torch.Tensor) -> torch.Tensor: if self.transform_input: assert x.dim() == 3 or x.dim() == 4 - assert x.min() >= 0.0 and x.max() <= 1.0 + if x.min() < 0.0 or x.max() > 1.0: + warn("Model input has values outside of the range [0, 1].") x = x.unsqueeze(0) if x.dim() == 3 else x x = x * 255 - 117 x = x[:, [2, 1, 0]] if self.bgr_transform else x diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py index dc1f704048..30c660841a 100644 --- a/captum/optim/models/_image/inception_v1_places365.py +++ b/captum/optim/models/_image/inception_v1_places365.py @@ -1,4 +1,5 @@ from typing import Any, Optional, Tuple, Type, Union, cast +from warnings import warn import torch import torch.nn as nn @@ -178,7 +179,8 @@ def __init__( def _transform_input(self, x: torch.Tensor) -> torch.Tensor: if self.transform_input: assert x.dim() == 3 or x.dim() == 4 - assert x.min() >= 0.0 and x.max() <= 1.0 + if x.min() < 0.0 or x.max() > 1.0: + warn("Model input has values outside of the range [0, 1].") x = x.unsqueeze(0) if x.dim() == 3 else x x = x * 255 - torch.tensor( [116.7894, 112.6004, 104.0437], device=x.device diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index d656894bd8..e895d2942c 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -80,6 +80,19 @@ def test_inceptionv1_transform(self) -> None: expected_output = x * 255 - 117 assertTensorAlmostEqual(self, output, expected_output, 0) + def test_inceptionv1_transform_warning(self) -> None: + if torch.__version__ <= "1.2.0": + raise unittest.SkipTest( + "Skipping InceptionV1 internal transform warning test due to" + + " insufficient Torch version." 
+ ) + x = torch.stack( + [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0 + ) + model = googlenet(pretrained=True) + with self.assertWarns(UserWarning): + output = model._transform_input(x) + def test_inceptionv1_transform_bgr(self) -> None: if torch.__version__ <= "1.2.0": raise unittest.SkipTest( @@ -198,6 +211,19 @@ def test_inceptionv1_places365_transform(self) -> None: expected_output = expected_output[:, [2, 1, 0]] assertTensorAlmostEqual(self, output, expected_output, 0) + def test_inceptionv1_places365_transform_warning(self) -> None: + if torch.__version__ <= "1.6.0": + raise unittest.SkipTest( + "Skipping InceptionV1 Places365 internal transform warning test due" + + " to insufficient Torch version." + ) + x = torch.stack( + [torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0 + ) + model = googlenet_places365(pretrained=True) + with self.assertWarns(UserWarning): + output = model._transform_input(x) + def test_inceptionv1_places365_load_and_forward(self) -> None: if torch.__version__ <= "1.6.0": raise unittest.SkipTest( From f8dba0259459f120297f976564f379bb8334686e Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 16 May 2021 09:12:55 -0600 Subject: [PATCH 4/7] Fix Flake8 errors --- tests/optim/models/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index e895d2942c..b6d70292df 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -91,7 +91,7 @@ def test_inceptionv1_transform_warning(self) -> None: ) model = googlenet(pretrained=True) with self.assertWarns(UserWarning): - output = model._transform_input(x) + model._transform_input(x) def test_inceptionv1_transform_bgr(self) -> None: if torch.__version__ <= "1.2.0": @@ -222,7 +222,7 @@ def test_inceptionv1_places365_transform_warning(self) -> None: ) model = googlenet_places365(pretrained=True) with self.assertWarns(UserWarning): - output = model._transform_input(x) + model._transform_input(x) def test_inceptionv1_places365_load_and_forward(self) -> None: if torch.__version__ <= "1.6.0": From 81f01a10a6eafd7e6fa18bf3dd517fe8def09d17 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 31 Jul 2021 09:28:48 -0600 Subject: [PATCH 5/7] Improve model docs --- captum/optim/models/_image/inception_v1.py | 108 ++++++++++++++++- .../models/_image/inception_v1_places365.py | 110 +++++++++++++++--- 2 files changed, 199 insertions(+), 19 deletions(-) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 83ce7b57b0..294ad89957 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -20,24 +20,37 @@ def googlenet( ) -> "InceptionV1": r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from `"Going Deeper with Convolutions" `_. + Args: + pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. + Default: False progress (bool, optional): If True, displays a progress bar of the download to stderr + Default: True model_path (str, optional): Optional path for InceptionV1 model file. + Default: None replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained model with Redirected ReLU in place of ReLU layers. + Default: *True* when pretrained is True otherwise *False* use_linear_modules_only (bool, optional): If True, return pretrained model with all nonlinear layers replaced with linear equivalents. 
+ Default: False aux_logits (bool, optional): If True, adds two auxiliary branches that can - improve training. Default: *False* when pretrained is True otherwise *True* + improve training. + Default: False out_features (int, optional): Number of output features in the model used for - training. Default: 1008 when pretrained is True. + training. + Default: 1008 transform_input (bool, optional): If True, preprocesses the input according to - the method with which it was trained on ImageNet. Default: *False* + the method with which it was trained on ImageNet. + Default: False bgr_transform (bool, optional): If True and transform_input is True, perform an RGB to BGR transform in the internal preprocessing. - Default: *False* + Default: False + + Returns: + **InceptionV1** (InceptionV1): An Inception5h model. """ if pretrained: @@ -81,6 +94,28 @@ def __init__( replace_relus_with_redirectedrelu: bool = False, use_linear_modules_only: bool = False, ) -> None: + """ + Args: + + replace_relus_with_redirectedrelu (bool, optional): If True, return + pretrained model with Redirected ReLU in place of ReLU layers. + Default: False + use_linear_modules_only (bool, optional): If True, return pretrained + model with all nonlinear layers replaced with linear equivalents. + Default: False + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. + Default: False + out_features (int, optional): Number of output features in the model used + for training. + Default: 1008 + transform_input (bool, optional): If True, preprocesses the input according + to the method with which it was trained on ImageNet. + Default: False + bgr_transform (bool, optional): If True and transform_input is True, + perform an RGB to BGR transform in the internal preprocessing. + Default: False + """ super().__init__() self.aux_logits = aux_logits self.transform_input = transform_input @@ -164,6 +199,14 @@ def __init__( self.fc = nn.Linear(1024, out_features) def _transform_input(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor): A transformed tensor. + """ if self.transform_input: assert x.dim() == 3 or x.dim() == 4 if x.min() < 0.0 or x.max() > 1.0: @@ -176,6 +219,15 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor: def forward( self, x: torch.Tensor ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Args: + + x (torch.Tensor): An input tensor to run through the model. + + Returns: + x (torch.Tensor or tuple of torch.Tensor): A single or multiple output + tensors from the model. + """ x = self._transform_input(x) x = self.conv1(x) x = self.conv1_relu(x) @@ -232,6 +284,24 @@ def __init__( activ: Type[nn.Module] = nn.ReLU, p_layer: Type[nn.Module] = nn.MaxPool2d, ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + inception module. + c1x1 (int, optional): The number of 1x1 conv output channels. + c3x3reduce (int, optional): The number of 3x3 reduce conv output channels. + c3x3 (int, optional): The number of 3x3 conv output channels. + c5x5reduce (int, optional): The number of 5x5 reduce conv output channels. + c5x5 (int, optional): The number of 5x5 conv output channels. + pool_proj (int, optional): The number of pool projection output channels. + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + p_layer (type of nn.Module, optional): The nn.Module class type to use for + pooling layers.
+ Default: nn.MaxPool2d + """ super().__init__() self.conv_1x1 = nn.Conv2d( in_channels=in_channels, @@ -289,6 +359,14 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the Inception Module. + + Returns: + x (torch.Tensor): The output tensor of the Inception Module. + """ c1x1 = self.conv_1x1(x) c3x3 = self.conv_3x3_reduce(x) @@ -311,6 +389,19 @@ def __init__( out_features: int = 1008, activ: Type[nn.Module] = nn.ReLU, ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + auxiliary branch. + Default: 508 + out_features (int, optional): The number of output features to use for the + auxiliary branch. + Default: 1008 + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + """ super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d((4, 4)) self.conv = nn.Conv2d( @@ -328,6 +419,15 @@ def __init__( self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the auxiliary branch + module. + + Returns: + x (torch.Tensor): The output tensor of the auxiliary branch module. + """ x = self.avg_pool(x) x = self.conv(x) x = self.conv_relu(x) diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py index 30c660841a..6e893451c5 100644 --- a/captum/optim/models/_image/inception_v1_places365.py +++ b/captum/optim/models/_image/inception_v1_places365.py @@ -26,20 +26,28 @@ def googlenet_places365( Args: pretrained (bool, optional): If True, returns a model pre-trained on the MIT - Places365 Standard dataset. + Places365 Standard dataset. + Default: False progress (bool, optional): If True, displays a progress bar of the download to stderr + Default: True model_path (str, optional): Optional path for InceptionV1 model file. + Default: None replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained model with Redirected ReLU in place of ReLU layers. + Default: *True* when pretrained is True otherwise *False* use_linear_modules_only (bool, optional): If True, return pretrained model with all nonlinear layers replaced with linear equivalents. + Default: False aux_logits (bool, optional): If True, adds two auxiliary branches that can - improve training. Default: *True* + improve training. + Default: True out_features (int, optional): Number of output features in the model used for training. Default: 365 when pretrained is True. + Default: 365 transform_input (bool, optional): If True, preprocesses the input according to - the method with which it was trained on Places365. Default: *True* + the method with which it was trained on Places365. + Default: True """ if pretrained: @@ -71,18 +79,6 @@ def googlenet_places365( class InceptionV1Places365(nn.Module): """ MIT Places365 variant of the InceptionV1 model. - - Args: - out_features (int, optional): Number of output features in the model used for - training. Default: 365 when pretrained is True. - aux_logits (bool, optional): If True, adds two auxiliary branches that can - improve training. Default: *True* - transform_input (bool, optional): If True, preprocesses the input according to - the method with which it was trained on Places365. 
Default: *True* - replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained - model with Redirected ReLU in place of ReLU layers. - use_linear_modules_only (bool, optional): If True, return pretrained - model with all nonlinear layers replaced with linear equivalents. """ __constants__ = ["aux_logits", "transform_input"] @@ -95,6 +91,25 @@ def __init__( out_features: int = 365, aux_logits: bool = True, transform_input: bool = True, replace_relus_with_redirectedrelu: bool = False, use_linear_modules_only: bool = False, ) -> None: + """ + Args: + + out_features (int, optional): Number of output features in the model used + for training. + Default: 365 + aux_logits (bool, optional): If True, adds two auxiliary branches that can + improve training. + Default: True + transform_input (bool, optional): If True, preprocesses the input according + to the method with which it was trained on Places365. + Default: True + replace_relus_with_redirectedrelu (bool, optional): If True, return + pretrained model with Redirected ReLU in place of ReLU layers. + Default: False + use_linear_modules_only (bool, optional): If True, return pretrained model + with all nonlinear layers replaced with linear equivalents. + Default: False + """ super().__init__() self.aux_logits = aux_logits self.transform_input = transform_input @@ -177,6 +192,14 @@ def __init__( self.fc = nn.Linear(1024, out_features) def _transform_input(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to normalize and scale the values of. + + Returns: + x (torch.Tensor): A transformed tensor. + """ if self.transform_input: assert x.dim() == 3 or x.dim() == 4 if x.min() < 0.0 or x.max() > 1.0: @@ -191,6 +214,15 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor: def forward( self, x: torch.Tensor ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Args: + + x (torch.Tensor): An input tensor to run through the model. + + Returns: + x (torch.Tensor or tuple of torch.Tensor): A single or multiple output + tensors from the model. + """ x = self._transform_input(x) x = self.conv1(x) x = self.conv1_relu(x) @@ -247,6 +279,24 @@ def __init__( activ: Type[nn.Module] = nn.ReLU, p_layer: Type[nn.Module] = nn.MaxPool2d, ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + inception module. + c1x1 (int, optional): The number of 1x1 conv output channels. + c3x3reduce (int, optional): The number of 3x3 reduce conv output channels. + c3x3 (int, optional): The number of 3x3 conv output channels. + c5x5reduce (int, optional): The number of 5x5 reduce conv output channels. + c5x5 (int, optional): The number of 5x5 conv output channels. + pool_proj (int, optional): The number of pool projection output channels. + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + p_layer (type of nn.Module, optional): The nn.Module class type to use for + pooling layers. + Default: nn.MaxPool2d + """ super().__init__() self.conv_1x1 = nn.Conv2d( in_channels=in_channels, out_channels=c1x1, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True, ) @@ -304,6 +354,14 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the Inception Module. + + Returns: + x (torch.Tensor): The output tensor of the Inception Module. + """ c1x1 = self.conv_1x1(x) c3x3 = self.conv_3x3_reduce(x) @@ -326,6 +384,19 @@ def __init__( out_features: int = 365, activ: Type[nn.Module] = nn.ReLU, ) -> None: + """ + Args: + + in_channels (int, optional): The number of input channels to use for the + auxiliary branch. + Default: 512 + out_features (int, optional): The number of output features to use for the + auxiliary branch.
+ Default: 365 + activ (type of nn.Module, optional): The nn.Module class type to use for + activation layers. + Default: nn.ReLU + """ super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d((4, 4)) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True, ) @@ -343,6 +414,15 @@ def __init__( self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + + x (torch.Tensor): An input tensor to pass through the auxiliary branch + module. + + Returns: + x (torch.Tensor): The output tensor of the auxiliary branch module. + """ x = self.avg_pool(x) x = self.conv(x) x = self.conv_relu(x) x = torch.flatten(x, 1) x = self.fc1(x) x = self.fc1_relu(x) x = self.dropout(x) x = self.fc2(x) return x From abf4535677e26184681c8cc6856e2f458f6fa281 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 12 Dec 2021 10:14:08 -0700 Subject: [PATCH 6/7] Add JIT support for models --- captum/optim/models/_common.py | 1 + captum/optim/models/_image/inception_v1.py | 2 +- .../models/_image/inception_v1_places365.py | 2 +- tests/optim/models/test_models.py | 50 +++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index fd468dfc1c..cf9a33955a 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -62,6 +62,7 @@ class RedirectedReluLayer(nn.Module): Class for applying RedirectedReLU """ + @torch.jit.ignore def forward(self, input: torch.Tensor) -> torch.Tensor: return RedirectedReLU.apply(input) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index 294ad89957..a96e893e98 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -266,7 +266,7 @@ def forward( x = self.drop(x) x = self.fc(x) if not self.aux_logits: - return cast(torch.Tensor, x) + return x else: return x, aux1_output, aux2_output diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py index 6e893451c5..930fd5797e 100644 --- a/captum/optim/models/_image/inception_v1_places365.py +++ b/captum/optim/models/_image/inception_v1_places365.py @@ -261,7 +261,7 @@ def forward( x = self.drop(x) x = self.fc(x) if not self.aux_logits: - return cast(torch.Tensor, x) + return x else: return x, aux1_output, aux2_output diff --git a/tests/optim/models/test_models.py b/tests/optim/models/test_models.py index b6d70292df..635b331094 100644 --- a/tests/optim/models/test_models.py +++ b/tests/optim/models/test_models.py @@ -158,6 +158,30 @@ def test_inceptionv1_forward_cuda(self) -> None: self.assertTrue(outputs.is_cuda) self.assertEqual(list(outputs.shape), [1, 1008]) + def test_inceptionv1_load_and_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 load & JIT test" + + " due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=True) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) + + def test_inceptionv1_load_and_jit_module_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 load & JIT with no" + + " redirected relu test due to insufficient Torch version."
+ ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet(pretrained=True, replace_relus_with_redirectedrelu=False) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual(list(outputs.shape), [1, 1008]) + class TestInceptionV1Places365(BaseTest): def test_load_inceptionv1_places365_with_redirected_relu(self) -> None: @@ -281,3 +305,29 @@ def test_inceptionv1_places365_forward_cuda(self) -> None: self.assertTrue(outputs[1].is_cuda) self.assertTrue(outputs[2].is_cuda) self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_load_and_jit_module(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 load & JIT module test" + + " due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365(pretrained=True) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) + + def test_inceptionv1_places365_jit_module_no_redirected_relu(self) -> None: + if torch.__version__ <= "1.8.0": + raise unittest.SkipTest( + "Skipping pretrained InceptionV1 Places365 load & JIT module with no" + + " redirected relu test due to insufficient Torch version." + ) + x = torch.zeros(1, 3, 224, 224) + model = googlenet_places365( + pretrained=True, replace_relus_with_redirectedrelu=False + ) + jit_model = torch.jit.script(model) + outputs = jit_model(x) + self.assertEqual([list(o.shape) for o in outputs], [[1, 365]] * 3) From 956c1cc17bf2da4d98d6882aff2da676d4f52d08 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 12 Dec 2021 10:21:59 -0700 Subject: [PATCH 7/7] Remove unused typing.cast import --- captum/optim/models/_image/inception_v1.py | 2 +- captum/optim/models/_image/inception_v1_places365.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/captum/optim/models/_image/inception_v1.py b/captum/optim/models/_image/inception_v1.py index a96e893e98..b9e534b91f 100644 --- a/captum/optim/models/_image/inception_v1.py +++ b/captum/optim/models/_image/inception_v1.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, Union, cast +from typing import Optional, Tuple, Type, Union from warnings import warn import torch diff --git a/captum/optim/models/_image/inception_v1_places365.py b/captum/optim/models/_image/inception_v1_places365.py index 930fd5797e..8fb2fd8924 100644 --- a/captum/optim/models/_image/inception_v1_places365.py +++ b/captum/optim/models/_image/inception_v1_places365.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, Type, Union, cast +from typing import Any, Optional, Tuple, Type, Union from warnings import warn import torch
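A minimal usage sketch for the model and class list added by this series, assuming the pretrained weight download succeeds; the random tensor below stands in for a real RGB image scaled to [0, 1] (out-of-range values now only trigger a UserWarning, per PATCH 3/7):

    import torch

    from captum.optim.models import INCEPTIONV1_PLACES365_CLASSES, googlenet_places365

    # Stand-in input: a batch of RGB images with values in [0, 1]. The model's
    # internal _transform_input handles the Places365 mean subtraction and the
    # RGB to BGR reordering.
    x = torch.rand(1, 3, 224, 224)

    model = googlenet_places365(pretrained=True).eval()

    # With the default aux_logits=True, the forward pass returns the main
    # logits followed by the two auxiliary branch outputs.
    with torch.no_grad():
        logits, aux1, aux2 = model(x)
    print(INCEPTIONV1_PLACES365_CLASSES[int(logits.argmax(1))])

    # As of PATCH 6/7, the model can also be scripted with TorchScript
    # (torch > 1.8.0, per the accompanying tests).
    jit_model = torch.jit.script(model)
    with torch.no_grad():
        logits, aux1, aux2 = jit_model(x)
    print(list(logits.shape))  # [1, 365]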