Skip to content

Commit 9540c29

Browse files
bottlerfacebook-github-bot
authored andcommitted
Make Module.__init__ automatic
Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else. Reviewed By: shapovalov Differential Revision: D42760349 fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
1 parent 97f8f9b commit 9540c29

29 files changed

+36
-87
lines changed

Diff for: projects/implicitron_trainer/README.md

-2
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
212212
class XRayRenderer(BaseRenderer, torch.nn.Module):
213213
n_pts_per_ray: int = 64
214214
215-
# if there are other base classes, make sure to call `super().__init__()` explicitly
216215
def __post_init__(self):
217-
super().__init__()
218216
# custom initialization
219217
220218
def forward(

Diff for: pytorch3d/implicitron/eval_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
130130
raise ValueError("Image size should be set in the dataset")
131131

132132
# init the simple DBIR model
133-
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
133+
model = ModelDBIR(
134134
render_image_width=image_size,
135135
render_image_height=image_size,
136136
bg_color=bg_color,

Diff for: pytorch3d/implicitron/models/base_model.py

-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
4949
# the training loop.
5050
log_vars: List[str] = field(default_factory=lambda: ["objective"])
5151

52-
def __init__(self) -> None:
53-
super().__init__()
54-
5552
def forward(
5653
self,
5754
*, # force keyword-only arguments

Diff for: pytorch3d/implicitron/models/feature_extractor/feature_extractor.py

-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
1515
Base class for an extractor of a set of features from images.
1616
"""
1717

18-
def __init__(self):
19-
super().__init__()
20-
2118
def get_feat_dims(self) -> int:
2219
"""
2320
Returns:

Diff for: pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py

-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
7878
feature_rescale: float = 1.0
7979

8080
def __post_init__(self):
81-
super().__init__()
8281
if self.normalize_image:
8382
# register buffers needed to normalize the image
8483
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):

Diff for: pytorch3d/implicitron/models/generic_model.py

-2
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
304304
)
305305

306306
def __post_init__(self):
307-
super().__init__()
308-
309307
if self.view_pooler_enabled:
310308
if self.image_feature_extractor_class_type is None:
311309
raise ValueError(

Diff for: pytorch3d/implicitron/models/global_encoder/autodecoder.py

-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
2929
ignore_input: bool = False
3030

3131
def __post_init__(self):
32-
super().__init__()
33-
3432
if self.n_instances <= 0:
3533
raise ValueError(f"Invalid n_instances {self.n_instances}")
3634

Diff for: pytorch3d/implicitron/models/global_encoder/global_encoder.py

-5
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
2626
(`SequenceAutodecoder`).
2727
"""
2828

29-
def __init__(self) -> None:
30-
super().__init__()
31-
3229
def get_encoding_dim(self):
3330
"""
3431
Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
6966
autodecoder: Autodecoder
7067

7168
def __post_init__(self):
72-
super().__init__()
7369
run_auto_creation(self)
7470

7571
def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
10399
time_divisor: float = 1.0
104100

105101
def __post_init__(self):
106-
super().__init__()
107102
self._harmonic_embedding = HarmonicEmbedding(
108103
n_harmonic_functions=self.n_harmonic_functions,
109104
append_input=self.append_input,

Diff for: pytorch3d/implicitron/models/implicit_function/base.py

-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414

1515

1616
class ImplicitFunctionBase(ABC, ReplaceableBase):
17-
def __init__(self):
18-
super().__init__()
19-
2017
@abstractmethod
2118
def forward(
2219
self,

Diff for: pytorch3d/implicitron/models/implicit_function/decoding_functions.py

-7
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
4545
space and transforms it into the required quantity (for example density and color).
4646
"""
4747

48-
def __post_init__(self):
49-
super().__init__()
50-
5148
def forward(
5249
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
5350
) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
8380
operation: DecoderActivation = DecoderActivation.IDENTITY
8481

8582
def __post_init__(self):
86-
super().__post_init__()
8783
if self.operation not in [
8884
DecoderActivation.RELU,
8985
DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
163159
use_xavier_init: bool = True
164160

165161
def __post_init__(self):
166-
super().__init__()
167-
168162
try:
169163
last_activation = {
170164
DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
284278
network: MLPWithInputSkips
285279

286280
def __post_init__(self):
287-
super().__post_init__()
288281
run_auto_creation(self)
289282

290283
def forward(

Diff for: pytorch3d/implicitron/models/implicit_function/idr_feature_field.py

-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
6666
encoding_dim: int = 0
6767

6868
def __post_init__(self):
69-
super().__init__()
70-
7169
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
7270

7371
self.embed_fn = None

Diff for: pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
5656
"""
5757

5858
def __post_init__(self):
59-
super().__init__()
6059
# The harmonic embedding layer converts input 3D coordinates
6160
# to a representation that is more suitable for
6261
# processing with a deep neural network.

Diff for: pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py

-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
4444
raymarch_function: Any = None
4545

4646
def __post_init__(self):
47-
super().__init__()
4847
self._harmonic_embedding = HarmonicEmbedding(
4948
self.n_harmonic_functions, append_input=True
5049
)
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
135134
ray_dir_in_camera_coords: bool = False
136135

137136
def __post_init__(self):
138-
super().__init__()
139137
self._harmonic_embedding = HarmonicEmbedding(
140138
self.n_harmonic_functions, append_input=True
141139
)
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
249247
xyz_in_camera_coords: bool = False
250248

251249
def __post_init__(self):
252-
super().__init__()
253250
raymarch_input_embedding_dim = (
254251
HarmonicEmbedding.get_output_dim_static(
255252
self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
335332
pixel_generator: SRNPixelGenerator
336333

337334
def __post_init__(self):
338-
super().__init__()
339335
run_auto_creation(self)
340336

341337
def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
393389
pixel_generator: SRNPixelGenerator
394390

395391
def __post_init__(self):
396-
super().__init__()
397392
run_auto_creation(self)
398393

399394
def create_hypernet(self) -> None:

Diff for: pytorch3d/implicitron/models/implicit_function/voxel_grid.py

-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
8181
)
8282

8383
def __post_init__(self):
84-
super().__init__()
8584
if 0 not in self.resolution_changes:
8685
raise ValueError("There has to be key `0` in `resolution_changes`.")
8786

@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
857856
param_groups: Dict[str, str] = field(default_factory=lambda: {})
858857

859858
def __post_init__(self):
860-
super().__init__()
861859
run_auto_creation(self)
862860
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
863861
shapes = self.voxel_grid.get_shapes(epoch=0)

Diff for: pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py

-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
186186
volume_cropping_epochs: Tuple[int, ...] = ()
187187

188188
def __post_init__(self) -> None:
189-
super().__init__()
190189
run_auto_creation(self)
191190
# pyre-ignore[16]
192191
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()

Diff for: pytorch3d/implicitron/models/metrics.py

-6
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
2525
depend on the model's parameters.
2626
"""
2727

28-
def __post_init__(self) -> None:
29-
super().__init__()
30-
3128
def forward(
3229
self, model: Any, keys_prefix: str = "loss_", **kwargs
3330
) -> Dict[str, Any]:
@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
5653
`forward()` method produces losses and other metrics.
5754
"""
5855

59-
def __post_init__(self) -> None:
60-
super().__init__()
61-
6256
def forward(
6357
self,
6458
raymarched: RendererOutput,

Diff for: pytorch3d/implicitron/models/model_dbir.py

-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
4141
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
4242
max_points: int = -1
4343

44-
def __post_init__(self):
45-
super().__init__()
46-
4744
def forward(
4845
self,
4946
*, # force keyword-only arguments

Diff for: pytorch3d/implicitron/models/renderer/base.py

-3
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
141141
Base class for all Renderer implementations.
142142
"""
143143

144-
def __init__(self) -> None:
145-
super().__init__()
146-
147144
def requires_object_mask(self) -> bool:
148145
"""
149146
Whether `forward` needs the object_mask.

Diff for: pytorch3d/implicitron/models/renderer/lstm_renderer.py

-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
5757
verbose: bool = False
5858

5959
def __post_init__(self):
60-
super().__init__()
6160
self._lstm = torch.nn.LSTMCell(
6261
input_size=self.n_feature_channels,
6362
hidden_size=self.hidden_size,

Diff for: pytorch3d/implicitron/models/renderer/multipass_ea.py

-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
9090
return_weights: bool = False
9191

9292
def __post_init__(self):
93-
super().__init__()
9493
self._refiners = {
9594
EvaluationMode.TRAINING: RayPointRefiner(
9695
n_pts_per_ray=self.n_pts_per_ray_fine_training,

Diff for: pytorch3d/implicitron/models/renderer/ray_point_refiner.py

-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
3838
random_sampling: bool
3939
add_input_samples: bool = True
4040

41-
def __post_init__(self) -> None:
42-
super().__init__()
43-
4441
def forward(
4542
self,
4643
input_ray_bundle: ImplicitronRayBundle,

Diff for: pytorch3d/implicitron/models/renderer/ray_sampler.py

-5
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
2020
Base class for ray samplers.
2121
"""
2222

23-
def __init__(self):
24-
super().__init__()
25-
2623
def forward(
2724
self,
2825
cameras: CamerasBase,
@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
10299
stratified_point_sampling_evaluation: bool = False
103100

104101
def __post_init__(self):
105-
super().__init__()
106-
107102
if (self.n_rays_per_image_sampled_from_mask is not None) and (
108103
self.n_rays_total_training is not None
109104
):

Diff for: pytorch3d/implicitron/models/renderer/ray_tracing.py

-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
4343
n_steps: int = 100
4444
n_secant_steps: int = 8
4545

46-
def __post_init__(self):
47-
super().__init__()
48-
4946
def forward(
5047
self,
5148
sdf: Callable[[torch.Tensor], torch.Tensor],

Diff for: pytorch3d/implicitron/models/renderer/raymarcher.py

-5
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
2222
and marching along them in order to generate a feature render.
2323
"""
2424

25-
def __init__(self):
26-
super().__init__()
27-
2825
def forward(
2926
self,
3027
rays_densities: torch.Tensor,
@@ -98,8 +95,6 @@ def __post_init__(self):
9895
surface_thickness: Denotes the overlap between the absorption
9996
function and the density function.
10097
"""
101-
super().__init__()
102-
10398
bg_color = torch.tensor(self.bg_color)
10499
if bg_color.ndim != 1:
105100
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")

Diff for: pytorch3d/implicitron/models/renderer/sdf_renderer.py

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
3535
def __post_init__(
3636
self,
3737
):
38-
super().__init__()
3938
render_features_dimensions = self.render_features_dimensions
4039
if len(self.bg_color) not in [1, render_features_dimensions]:
4140
raise ValueError(

Diff for: pytorch3d/implicitron/models/view_pooler/feature_aggregator.py

-12
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
118118
the outputs.
119119
"""
120120

121-
def __post_init__(self):
122-
super().__init__()
123-
124121
def get_aggregated_feature_dim(
125122
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
126123
):
@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
181178
ReductionFunction.STD,
182179
)
183180

184-
def __post_init__(self):
185-
super().__init__()
186-
187181
def get_aggregated_feature_dim(
188182
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
189183
):
@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
275269
weight_by_ray_angle_gamma: float = 1.0
276270
min_ray_angle_weight: float = 0.1
277271

278-
def __post_init__(self):
279-
super().__init__()
280-
281272
def get_aggregated_feature_dim(
282273
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
283274
):
@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
377368
weight_by_ray_angle_gamma: float = 1.0
378369
min_ray_angle_weight: float = 0.1
379370

380-
def __post_init__(self):
381-
super().__init__()
382-
383371
def get_aggregated_feature_dim(
384372
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
385373
):

Diff for: pytorch3d/implicitron/models/view_pooler/view_pooler.py

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
3838
feature_aggregator: FeatureAggregatorBase
3939

4040
def __post_init__(self):
41-
super().__init__()
4241
run_auto_creation(self)
4342

4443
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):

Diff for: pytorch3d/implicitron/models/view_pooler/view_sampler.py

-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
2929
masked_sampling: bool = False
3030
sampling_mode: str = "bilinear"
3131

32-
def __post_init__(self):
33-
super().__init__()
34-
3532
def forward(
3633
self,
3734
*, # force kw args

0 commit comments

Comments
 (0)