Skip to content

Commit c759fc5

Browse files
Dejan Kovachevfacebook-github-bot
Dejan Kovachev
authored andcommitted
Hard population of registry system with pre_expand
Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports. Reviewed By: bottler Differential Revision: D44504122 fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
1 parent 813e941 commit c759fc5

File tree

5 files changed

+117
-27
lines changed

5 files changed

+117
-27
lines changed

pytorch3d/implicitron/dataset/data_source.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,8 @@
1313
)
1414
from pytorch3d.renderer.cameras import CamerasBase
1515

16-
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
1716
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
1817
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
19-
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
20-
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
21-
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
22-
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
2318

2419

2520
class DataSourceBase(ReplaceableBase):
@@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
6055
data_loader_map_provider: DataLoaderMapProviderBase
6156
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
6257

58+
@classmethod
59+
def pre_expand(cls) -> None:
60+
# use try/finally to bypass cinder's lazy imports
61+
try:
62+
from .blender_dataset_map_provider import ( # noqa: F401
63+
BlenderDatasetMapProvider,
64+
)
65+
from .json_index_dataset_map_provider import ( # noqa: F401
66+
JsonIndexDatasetMapProvider,
67+
)
68+
from .json_index_dataset_map_provider_v2 import ( # noqa: F401
69+
JsonIndexDatasetMapProviderV2,
70+
)
71+
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa: F401
72+
from .rendered_mesh_dataset_map_provider import ( # noqa: F401
73+
RenderedMeshDatasetMapProvider,
74+
)
75+
finally:
76+
pass
77+
6378
def __post_init__(self):
6479
run_auto_creation(self)
6580
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None

pytorch3d/implicitron/models/generic_model.py

+31-22
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,8 @@
2020
ImplicitronRender,
2121
)
2222
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
23-
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa
24-
ResNetFeatureExtractor,
25-
)
2623
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
2724
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
28-
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa
29-
IdrFeatureField,
30-
)
31-
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
32-
NeRFormerImplicitFunction,
33-
)
34-
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
35-
SRNHyperNetImplicitFunction,
36-
)
37-
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
38-
VoxelGridImplicitFunction,
39-
)
4025
from pytorch3d.implicitron.models.metrics import (
4126
RegularizationMetricsBase,
4227
ViewMetricsBase,
@@ -50,14 +35,7 @@
5035
RendererOutput,
5136
RenderSamplingMode,
5237
)
53-
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer # noqa
54-
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
55-
MultiPassEmissionAbsorptionRenderer,
56-
)
5738
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
58-
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
59-
SignedDistanceFunctionRenderer,
60-
)
6139

6240
from pytorch3d.implicitron.models.utils import (
6341
apply_chunked,
@@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
315293
]
316294
)
317295

296+
@classmethod
297+
def pre_expand(cls) -> None:
298+
# use try/finally to bypass cinder's lazy imports
299+
try:
300+
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa: F401, B950
301+
ResNetFeatureExtractor,
302+
)
303+
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
304+
IdrFeatureField,
305+
)
306+
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
307+
NeRFormerImplicitFunction,
308+
)
309+
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
310+
SRNHyperNetImplicitFunction,
311+
)
312+
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa: F401, B950
313+
VoxelGridImplicitFunction,
314+
)
315+
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
316+
LSTMRenderer,
317+
)
318+
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
319+
MultiPassEmissionAbsorptionRenderer,
320+
)
321+
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
322+
SignedDistanceFunctionRenderer,
323+
)
324+
finally:
325+
pass
326+
318327
def __post_init__(self):
319328
if self.view_pooler_enabled:
320329
if self.image_feature_extractor_class_type is None:

pytorch3d/implicitron/models/overfit_model.py

+25
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
258258
]
259259
)
260260

261+
@classmethod
262+
def pre_expand(cls) -> None:
263+
# use try/finally to bypass cinder's lazy imports
264+
try:
265+
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
266+
IdrFeatureField,
267+
)
268+
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
269+
NeuralRadianceFieldImplicitFunction,
270+
)
271+
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
272+
SRNImplicitFunction,
273+
)
274+
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
275+
LSTMRenderer,
276+
)
277+
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa: F401
278+
MultiPassEmissionAbsorptionRenderer,
279+
)
280+
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
281+
SignedDistanceFunctionRenderer,
282+
)
283+
finally:
284+
pass
285+
261286
def __post_init__(self):
262287
# The attribute will be filled by run_auto_creation
263288
run_auto_creation(self)

pytorch3d/implicitron/tools/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __post_init__(self):
185185
IMPL_SUFFIX: str = "_impl"
186186
TWEAK_SUFFIX: str = "_tweak_args"
187187
_DATACLASS_INIT: str = "__dataclass_own_init__"
188+
PRE_EXPAND_NAME: str = "pre_expand"
188189

189190

190191
class ReplaceableBase:
@@ -838,6 +839,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
838839
In addition, if the class inherits torch.nn.Module, the generated __init__ will
839840
call torch.nn.Module's __init__ before doing anything else.
840841
842+
Before any transformation of the class, if the class has a classmethod called
843+
`pre_expand`, it will be called with no arguments.
844+
841845
Note that although the *_args members are intended to have type DictConfig, they
842846
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
843847
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
@@ -858,6 +862,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
858862
if _is_actually_dataclass(some_class):
859863
return some_class
860864

865+
if hasattr(some_class, PRE_EXPAND_NAME):
866+
getattr(some_class, PRE_EXPAND_NAME)()
867+
861868
# The functions this class's run_auto_creation will run.
862869
creation_functions: List[str] = []
863870
# The classes which this type knows about from the registry

tests/implicitron/test_config.py

+34
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dataclasses import dataclass, field, is_dataclass
1111
from enum import Enum
1212
from typing import Any, Dict, List, Optional, Tuple
13+
from unittest.mock import Mock
1314

1415
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
1516
from pytorch3d.implicitron.tools.config import (
@@ -805,6 +806,39 @@ def __post_init__(self):
805806

806807
self.assertEqual(control_args, ["Orange", "Orange", True, True])
807808

809+
def test_pre_expand(self):
810+
# Check that the precreate method of a class is called once before
811+
# when expand_args_fields is called on the class.
812+
813+
class A(Configurable):
814+
n: int = 9
815+
816+
@classmethod
817+
def pre_expand(cls):
818+
pass
819+
820+
A.pre_expand = Mock()
821+
expand_args_fields(A)
822+
A.pre_expand.assert_called()
823+
824+
def test_pre_expand_replaceable(self):
825+
# Check that the precreate method of a class is called once before
826+
# when expand_args_fields is called on the class.
827+
828+
class A(ReplaceableBase):
829+
pass
830+
831+
@classmethod
832+
def pre_expand(cls):
833+
pass
834+
835+
class A1(A):
836+
n: 9
837+
838+
A.pre_expand = Mock()
839+
expand_args_fields(A1)
840+
A.pre_expand.assert_called()
841+
808842

809843
@dataclass(eq=False)
810844
class MockDataclass:

0 commit comments

Comments
 (0)