Skip to content

Commit

Permalink
Merge branch 'master' into feature/SG-000-fix-module
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe authored Jun 28, 2023
2 parents c340da8 + ac25b2b commit af468c3
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
logger = get_logger(__name__)


def get_architecture(model_name: str, arch_params: HpmStruct, download_required_code: bool = True) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
def get_architecture(
model_name: str, arch_params: HpmStruct, download_required_code: bool = True, download_platform_weights: bool = True
) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
"""
Get the corresponding architecture class.
Expand All @@ -35,10 +37,15 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.
:param download_platform_weights: bool, when getting a model from the platform, whether to downlaod the pretrained weights as well.
In any other case this parameter will be ignored. (default=True).
:return:
- architecture_cls: Class of the model
- arch_params: Might be updated if loading from remote deci lab
- pretrained_weights_path: path to the pretrained weights from deci lab (None for local models).
- pretrained_weights_path: path to the pretrained weights from deci lab (None for local models or when deci
client is not enabled).
- is_remote: True if loading from remote deci lab
"""
pretrained_weights_path = None
Expand All @@ -62,8 +69,10 @@ def get_architecture(model_name: str, arch_params: HpmStruct, download_required_
if download_required_code: # Some extra code might be required to instantiate the arch params.
deci_client.download_and_load_model_additional_code(model_name, target_path=str(Path.cwd()))
_arch_params = hydra.utils.instantiate(_arch_params)

pretrained_weights_path = deci_client.get_model_weights(model_name)
if download_platform_weights:
pretrained_weights_path = deci_client.get_model_weights(model_name)
else:
pretrained_weights_path = None
model_name = _arch_params["model_name"]
del _arch_params["model_name"]
_arch_params = HpmStruct(**_arch_params)
Expand Down Expand Up @@ -99,8 +108,10 @@ def instantiate_model(
if arch_params is None:
arch_params = {}
arch_params = core_utils.HpmStruct(**arch_params)

architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(model_name, arch_params, download_required_code)
download_platform_weights = isinstance(pretrained_weights, str) and pretrained_weights.startswith("platform/")
architecture_cls, arch_params, pretrained_weights_path, is_remote = get_architecture(
model_name, arch_params, download_required_code, download_platform_weights
)

if not issubclass(architecture_cls, SgModule):
net = architecture_cls(**arch_params.to_dict(include_schema=False))
Expand All @@ -117,18 +128,21 @@ def instantiate_model(
if pretrained_weights is None and num_classes is None:
raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

if pretrained_weights:
if pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif pretrained_weights and pretrained_weights is None:
raise ValueError(f"Unknown pretrained_weights - couldn't find pretrained weights in {PRETRAINED_NUM_CLASSES.keys()} or platform.")

# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take **kwargs instead
# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take
# **kwargs instead
if "arch_params" not in get_callable_param_names(architecture_cls):
net = architecture_cls(**arch_params.to_dict(include_schema=False))
else:
net = architecture_cls(arch_params=arch_params)

if pretrained_weights:
if is_remote:
if is_remote and pretrained_weights_path:
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
else:
load_pretrained_weights(net, model_name, pretrained_weights)
Expand Down

0 comments on commit af468c3

Please sign in to comment.