|
1 | | -from typing import Optional, Tuple, Union |
| 1 | +from typing import Any, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | import torch.nn as nn |
5 | 5 |
|
6 | 6 | from captum.optim._param.image.transform import center_crop_shape |
7 | 7 | from captum.optim._utils.models import collect_activations |
8 | | -from captum.optim._utils.typing import ModelInputType, TransformSize |
| 8 | +from captum.optim._utils.typing import ModelInputType, PoolParam, TransformSize |
9 | 9 |
|
10 | 10 |
|
11 | 11 | def get_expanded_weights( |
@@ -56,8 +56,57 @@ def get_expanded_weights( |
56 | 56 | retain_graph=True, |
57 | 57 | )[0] |
58 | 58 | A.append(x.squeeze(0)) |
59 | | - exapnded_weights = torch.stack(A, 0) |
| 59 | + expanded_weights = torch.stack(A, 0) |
60 | 60 |
|
61 | 61 | if crop_shape is not None: |
62 | | - exapnded_weights = center_crop_shape(exapnded_weights, crop_shape) |
63 | | - return exapnded_weights |
| 62 | + expanded_weights = center_crop_shape(expanded_weights, crop_shape) |
| 63 | + return expanded_weights |
| 64 | + |
| 65 | + |
def max2avg_pool2d(
    model: torch.nn.Module, value: Optional[Any] = float("-inf")
) -> None:
    """
    Replace all non-linear MaxPool2d layers with their linear AvgPool2d equivalents.
    This allows us to ignore non-linear values when calculating expanded weights.

    The replacement is performed in-place and recursively, so MaxPool2d layers
    nested at any depth inside child modules are swapped as well. Only the
    kernel_size, stride, padding and ceil_mode settings of each MaxPool2d layer
    are carried over to its replacement.

    Args:
        model (nn.Module): A PyTorch model instance. It is modified in-place.
        value (Any, optional): Sentinel value to look for in the output of every
            replacement pooling layer. Output elements equal to this value are
            set back to zero, so that padding which is meant to be ignored by
            pooling layers does not leak into later layers. Set to None to
            disable the reset.
            Default: float("-inf")
    """

    class AvgPool2dInf(torch.nn.Module):
        """
        AvgPool2d wrapper that zeroes every output element equal to a sentinel
        value.
        """

        def __init__(
            self,
            kernel_size: PoolParam = 2,
            stride: Optional[PoolParam] = 2,
            padding: PoolParam = 0,
            ceil_mode: bool = False,
            value: Optional[Any] = None,
        ) -> None:
            super().__init__()
            self.avgpool = torch.nn.AvgPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                ceil_mode=ceil_mode,
            )
            self.value = value

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.avgpool(x)
            if self.value is not None:
                # Reset outputs that match the sentinel (e.g. -inf padding that
                # was averaged into the window) back to zero.
                x[x == self.value] = 0.0
            return x

    # Iterate over a snapshot, since entries of model._modules are replaced
    # while walking through them.
    for name, child in list(model._modules.items()):
        if isinstance(child, torch.nn.MaxPool2d):
            new_layer = AvgPool2dInf(
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                ceil_mode=child.ceil_mode,
                value=value,
            )
            setattr(model, name, new_layer)
        elif child is not None:
            # Forward the caller's value so nested layers use the same sentinel
            # instead of silently falling back to the default.
            max2avg_pool2d(child, value)
0 commit comments