Skip to content

Commit e424fd7

Browse files
committed
Fixing LevelMapper.
1 parent d093eb9 commit e424fd7

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchvision/ops/poolers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def initLevelMapper(
4040
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
4141

4242

43-
class LevelMapper(object):
43+
class LevelMapper(nn.Module):
4444
"""Determine which FPN level each RoI in a set of RoIs should map to based
4545
on the heuristic in the FPN paper.
4646
@@ -60,13 +60,14 @@ def __init__(
6060
canonical_level: int = 4,
6161
eps: float = 1e-6,
6262
):
63+
super().__init__()
6364
self.k_min = k_min
6465
self.k_max = k_max
6566
self.s0 = canonical_scale
6667
self.lvl0 = canonical_level
6768
self.eps = eps
6869

69-
def __call__(self, boxlists: List[Tensor]) -> Tensor:
70+
def forward(self, boxlists: List[Tensor]) -> Tensor:
7071
"""
7172
Args:
7273
boxlists (list[BoxList])
@@ -117,8 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
117118
"""
118119

119120
__annotations__ = {
120-
'scales': Optional[List[float]],
121-
'map_levels': Optional[LevelMapper]
121+
'scales': Optional[List[float]]
122122
}
123123

124124
def __init__(
@@ -137,7 +137,7 @@ def __init__(
137137
self.sampling_ratio = sampling_ratio
138138
self.output_size = tuple(output_size)
139139
self.scales = None
140-
self.map_levels = None
140+
self.map_levels: Optional[LevelMapper] = None
141141
self.canonical_scale = canonical_scale
142142
self.canonical_level = canonical_level
143143

0 commit comments

Comments
 (0)