Getting rid of "module." heritage #1184

Merged: 25 commits, merged on Jun 28, 2023
Changes from 19 commits

Commits
2613ece  Getting rid of "module." heritage (BloodAxe, Jun 16, 2023)
719c140  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 22, 2023)
80b17b7  Remove import of "WrappedModel" (BloodAxe, Jun 22, 2023)
52886c9  Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into … (BloodAxe, Jun 22, 2023)
c649db8  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 26, 2023)
7e72376  Remove remaining usages of .net.module (BloodAxe, Jun 26, 2023)
ba840fe  Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into … (BloodAxe, Jun 26, 2023)
66618b7  Remove remaining usages of .net.module (BloodAxe, Jun 26, 2023)
8025b83  Remove remaining usages of .net.module (BloodAxe, Jun 26, 2023)
dc4924d  Remove remaining usages of .net.module (BloodAxe, Jun 26, 2023)
74f4259  Merge branch 'master' into feature/SG-000-fix-module (shaydeci, Jun 26, 2023)
84b9f34  Put back WrappedModel class in place, but add deprecation warning whe… (BloodAxe, Jun 27, 2023)
553192c  Fix _yolox_ckpt_solver (updating condition to account missing "module.") (BloodAxe, Jun 27, 2023)
c2e0a10  Change python3.8 to python (BloodAxe, Jun 27, 2023)
d0ee8c8  Merge remote-tracking branch 'origin/feature/SG-000-fix-module' into … (BloodAxe, Jun 27, 2023)
49bd2ff  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 27, 2023)
341cf32  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 27, 2023)
6e19867  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
2781de7  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
03540e6  Reorder tests (BloodAxe, Jun 28, 2023)
c340da8  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
af468c3  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
d15eb30  Add missing unwrap_model after merge with master (BloodAxe, Jun 28, 2023)
4040f87  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
2df464a  Merge branch 'master' into feature/SG-000-fix-module (BloodAxe, Jun 28, 2023)
10 changes: 5 additions & 5 deletions Makefile
@@ -8,9 +8,9 @@ yolo_nas_integration_tests:
python -m unittest tests/integration_tests/yolo_nas_integration_test.py

recipe_accuracy_tests:
-python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
-python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
-python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
-python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+python src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
+python src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
+python src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+python src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
+python src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py
9 changes: 5 additions & 4 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -27,6 +27,7 @@
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.training.utils.ema import KDModelEMA
+from super_gradients.training.utils.utils import unwrap_model

logger = get_logger(__name__)

@@ -211,7 +212,7 @@ def _load_checkpoint_to_model(self):
the entire KD network following the same logic as in Trainer.
"""
teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
-teacher_net = self.net.module.teacher
+teacher_net = unwrap_model(self.net).teacher

if teacher_checkpoint_path is not None:

@@ -271,12 +272,12 @@ def _save_best_checkpoint(self, epoch, state):
Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
"""
if self.ema:
-best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
+best_net = self.ema_model.ema.student
state.pop("ema_net")
else:
-best_net = core_utils.WrappedModel(self.net.module.student)
+best_net = self.net.student

state["net"] = best_net.state_dict()
state["net"] = unwrap_model(best_net).state_dict()
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

def train(
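
For context on why this matters: wrapping a model before calling state_dict() is exactly what used to put the "module." prefix into saved checkpoints. A small plain-PyTorch illustration with a toy module (not SuperGradients code):

from torch import nn

student = nn.Linear(4, 2)
print(list(student.state_dict().keys()))    # ['weight', 'bias']

# Wrapping adds a "module." level to every key, the same effect the old WrappedModel had.
wrapped = nn.DataParallel(student)
print(list(wrapped.state_dict().keys()))    # ['module.weight', 'module.bias']

Saving unwrap_model(best_net).state_dict(), as the new code does, keeps the student's keys prefix-free, so the checkpoint loads into a bare model without renaming keys.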
41 changes: 21 additions & 20 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -66,7 +66,7 @@
from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.optimizer_utils import build_optimizer
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
-from super_gradients.training.utils.utils import fuzzy_idx_in_list
+from super_gradients.training.utils.utils import fuzzy_idx_in_list, unwrap_model
from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
from super_gradients.training.metrics import Accuracy, Top5
from super_gradients.training.utils import random_seed
@@ -396,9 +396,6 @@ def _net_to_device(self):
local_rank = int(device_config.device.split(":")[1])
self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

-else:
-self.net = core_utils.WrappedModel(self.net)

def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
"""
train_epoch - A single epoch training procedure
@@ -584,7 +581,7 @@ def _save_checkpoint(
"""
# WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return
if validation_results_tuple is None:
-self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": self.net.state_dict()}, global_step=epoch)
+self.sg_logger.add_checkpoint(tag="ckpt_latest_weights_only.pth", state_dict={"net": unwrap_model(self.net).state_dict()}, global_step=epoch)
return

# COMPUTE THE CURRENT metric
@@ -596,17 +593,18 @@
)

# BUILD THE state_dict
state = {"net": self.net.state_dict(), "acc": metric, "epoch": epoch}
state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()

if self.scaler is not None:
state["scaler_state_dict"] = self.scaler.state_dict()

if self.ema:
state["ema_net"] = self.ema_model.ema.state_dict()
state["ema_net"] = unwrap_model(self.ema_model.ema).state_dict()

-if isinstance(self.net.module, HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
state["processing_params"] = self.valid_loader.dataset.get_dataset_preprocessing_params()

# SAVES CURRENT MODEL AS ckpt_latest
@@ -630,7 +628,7 @@
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))

if self.training_params.average_best_models:
-net_for_averaging = self.ema_model.ema if self.ema else self.net
+net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
state["net"] = self.model_weight_averaging.get_average_model(net_for_averaging, validation_results_tuple=validation_results_tuple)
self.sg_logger.add_checkpoint(tag=self.average_model_checkpoint_filename, state_dict=state, global_step=epoch)

@@ -645,7 +643,7 @@ def _prep_net_for_train(self) -> None:
self._net_to_device()

# SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE
-self.update_param_groups = hasattr(self.net.module, "update_param_groups")
+self.update_param_groups = hasattr(unwrap_model(self.net), "update_param_groups")

self.checkpoint = {}
self.strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
@@ -1154,7 +1152,9 @@ def forward(self, inputs, targets):
if not self.ddp_silent_mode:
if self.training_params.dataset_statistics:
dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
-dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
+dataset_statistics_logger.analyze(
+    self.train_loader, all_classes=self.classes, title="Train-set", anchors=unwrap_model(self.net).arch_params.anchors
+)
dataset_statistics_logger.analyze(self.valid_loader, all_classes=self.classes, title="val-set")

sg_trainer_utils.log_uncaught_exceptions(logger)
@@ -1168,7 +1168,7 @@ def forward(self, inputs, targets):
if isinstance(self.training_params.optimizer, str) or (
inspect.isclass(self.training_params.optimizer) and issubclass(self.training_params.optimizer, torch.optim.Optimizer)
):
-self.optimizer = build_optimizer(net=self.net, lr=self.training_params.initial_lr, training_params=self.training_params)
+self.optimizer = build_optimizer(net=unwrap_model(self.net), lr=self.training_params.initial_lr, training_params=self.training_params)
elif isinstance(self.training_params.optimizer, torch.optim.Optimizer):
self.optimizer = self.training_params.optimizer
else:
@@ -1285,7 +1285,7 @@ def forward(self, inputs, targets):
num_gpus=get_world_size(),
)

-# model switch - we replace self.net.module with the ema model for the testing and saving part
+# model switch - we replace self.net with the ema model for the testing and saving part
# and then switch it back before the next training epoch
if self.ema:
self.ema_model.update_attr(self.net)
@@ -1345,9 +1345,9 @@ def forward(self, inputs, targets):
self.sg_logger.close()

def _set_net_preprocessing_from_valid_loader(self):
-if isinstance(self.net.module, HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
+if isinstance(unwrap_model(self.net), HasPredict) and isinstance(self.valid_loader.dataset, HasPreprocessingParams):
try:
-self.net.module.set_dataset_processing_params(**self.valid_loader.dataset.get_dataset_preprocessing_params())
+unwrap_model(self.net).set_dataset_processing_params(**self.valid_loader.dataset.get_dataset_preprocessing_params())
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling"
@@ -1403,7 +1403,7 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
def hook(module, _):
module.forward = MultiGPUModeAutocastWrapper(module.forward)

-self.net.module.register_forward_pre_hook(hook=hook)
+unwrap_model(self.net).register_forward_pre_hook(hook=hook)

if self.load_checkpoint:
scaler_state_dict = core_utils.get_param(self.checkpoint, "scaler_state_dict")
@@ -1429,7 +1429,7 @@ def _validate_final_average_model(self, cleanup_snapshots_pkl_file=False):
with wait_for_the_master(local_rank):
average_model_sd = read_ckpt_state_dict(average_model_ckpt_path)["net"]

-self.net.load_state_dict(average_model_sd)
+unwrap_model(self.net).load_state_dict(average_model_sd)
# testing the averaged model and save instead of best model if needed
averaged_model_results_tuple = self._validate_epoch(epoch=self.max_epochs)

@@ -1453,7 +1453,7 @@ def get_arch_params(self):

@property
def get_structure(self):
-return self.net.module.structure
+return unwrap_model(self.net).structure

@property
def get_architecture(self):
@@ -1485,7 +1485,8 @@ def _re_build_model(self, arch_params={}):

if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
logger.warning("Warning: distributed training is not supported in re_build_model()")
-self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids()) if device_config.multi_gpu else core_utils.WrappedModel(self.net)
+if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
+    self.net = torch.nn.DataParallel(self.net, device_ids=get_device_ids())

@property
def get_module(self):
@@ -1626,7 +1627,7 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
if "model_name" in get_callable_param_names(sg_logger_cls.__init__):
if sg_logger_params.get("model_name") is None:
# Use the model name used in `models.get(...)` if relevant
sg_logger_params["model_name"] = get_model_name(self.net.module)
sg_logger_params["model_name"] = get_model_name(unwrap_model(self.net))

if sg_logger_params["model_name"] is None:
raise ValueError(
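
A note on why unwrap_model also appears inside the isinstance and hasattr checks above: DataParallel/DistributedDataParallel wrappers neither inherit the wrapped model's interfaces nor forward arbitrary attributes, so such checks against the wrapper quietly return False. A self-contained illustration (HasPredict here is a hypothetical stand-in for the SuperGradients interface of the same name):

from torch import nn

class HasPredict:
    # hypothetical stand-in for super_gradients.module_interfaces.HasPredict
    def predict(self, x):
        raise NotImplementedError

class ToyModel(nn.Linear, HasPredict):
    def predict(self, x):
        return self(x)

model = ToyModel(4, 2)
wrapped = nn.DataParallel(model)

print(isinstance(wrapped, HasPredict))         # False: the wrapper hides the interface
print(hasattr(wrapped, "predict"))             # False: custom attributes are not forwarded
print(isinstance(wrapped.module, HasPredict))  # True: check the unwrapped model instead

The old self.net.module spelling had the opposite failure mode: it crashed whenever self.net was not wrapped, which is why bare models used to be forced into a WrappedModel. Going through unwrap_model covers both cases.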
2 changes: 0 additions & 2 deletions src/super_gradients/training/utils/__init__.py
@@ -1,7 +1,6 @@
from super_gradients.training.utils.utils import (
Timer,
HpmStruct,
-WrappedModel,
convert_to_tensor,
get_param,
tensor_container_to_device,
@@ -17,7 +16,6 @@
__all__ = [
"Timer",
"HpmStruct",
"WrappedModel",
"convert_to_tensor",
"get_param",
"tensor_container_to_device",
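
WrappedModel disappears from the package's public exports here. According to commit 84b9f34 the class itself was put back in place with a deprecation warning, so existing imports keep working for now; that code is not part of this diff, but a plausible shim along those lines would be:

import warnings
from torch import nn

class WrappedModel(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        warnings.warn(
            "WrappedModel is deprecated: models are no longer wrapped, pass the model itself instead.",
            DeprecationWarning,
        )
        self.module = module  # preserves the legacy "module." attribute path

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)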
16 changes: 10 additions & 6 deletions src/super_gradients/training/utils/callbacks/callbacks.py
@@ -23,7 +23,7 @@
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path

+from super_gradients.training.utils.utils import unwrap_model

logger = get_logger(__name__)

@@ -75,7 +75,7 @@ def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_bat
self.atol = kwargs.get("atol", 1e-05)

def __call__(self, context: PhaseContext):
-model = copy.deepcopy(context.net.module)
+model = copy.deepcopy(unwrap_model(context.net))
model = model.cpu()
model.eval() # Put model into eval mode

@@ -204,12 +204,12 @@ def __call__(self, context: PhaseContext) -> None:
:param context: Training phase context
"""
try:
-model = copy.deepcopy(context.net)
+model = copy.deepcopy(unwrap_model(context.net))
model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
model_state_dict = torch.load(model_state_dict_path)["net"]
model.load_state_dict(state_dict=model_state_dict)

-model = model.module.cpu()
+model = model.cpu()
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(input_size=self.input_dimensions)

@@ -267,7 +267,9 @@ def perform_scheduling(self, context: PhaseContext):

def update_lr(self, optimizer, epoch, batch_idx=None):
if self.update_param_groups:
-param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+param_groups = unwrap_model(self.net).update_param_groups(
+    optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+)
optimizer.param_groups = param_groups
else:
# UPDATE THE OPTIMIZERS PARAMETER
@@ -373,7 +375,9 @@ def update_lr(self, optimizer, epoch, batch_idx=None):
:return:
"""
if self.update_param_groups:
-param_groups = self.net.module.update_param_groups(optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len)
+param_groups = unwrap_model(self.net).update_param_groups(
+    optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
+)
optimizer.param_groups = param_groups
else:
# UPDATE THE OPTIMIZERS PARAMETER
14 changes: 12 additions & 2 deletions src/super_gradients/training/utils/checkpoint_utils.py
@@ -14,6 +14,7 @@
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
+from super_gradients.training.utils.utils import unwrap_model

try:
from torch.hub import download_url_to_file, load_state_dict_from_url
@@ -54,6 +55,13 @@ def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Uni
:return:
"""
state_dict = state_dict["net"] if "net" in state_dict else state_dict

+# This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
+# and contains "module." prefix in all keys
+# If all keys start with "module.", then we remove it.
+if all([key.startswith("module.") for key in state_dict.keys()]):
+    state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])

try:
strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
net.load_state_dict(state_dict, strict=strict_bool)
@@ -217,6 +225,8 @@ def load_checkpoint_to_model(
if isinstance(strict, str):
strict = StrictLoad(strict)

+net = unwrap_model(net)

if load_backbone and not hasattr(net, "backbone"):
raise ValueError("No backbone attribute in net - Can't load backbone weights")

@@ -239,7 +249,7 @@
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

-if (isinstance(net, HasPredict) or (hasattr(net, "module") and isinstance(net.module, HasPredict))) and load_processing_params:
+if (isinstance(net, HasPredict)) and load_processing_params:
if "processing_params" not in checkpoint.keys():
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
try:
@@ -275,7 +285,7 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):

if (
ckpt_val.shape != model_val.shape
and ckpt_key == "module._backbone._modules_list.0.conv.conv.weight"
and (ckpt_key == "module._backbone._modules_list.0.conv.conv.weight" or ckpt_key == "_backbone._modules_list.0.conv.conv.weight")
and model_key == "_backbone._modules_list.0.conv.weight"
):
model_val.data[:, :, ::2, ::2] = ckpt_val.data[:, :3]
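
The new branch in adaptive_load_state_dict is what keeps legacy checkpoints, saved from wrapped models and therefore carrying "module."-prefixed keys, loadable into the now-unwrapped models. The same idea as a standalone snippet (toy two-key state dict; the real function additionally handles strict-mode loading and error fallbacks):

import collections

import torch
from torch import nn

# A legacy checkpoint saved from a DataParallel/DDP-wrapped model has "module."-prefixed keys.
legacy = collections.OrderedDict([("module.weight", torch.zeros(2, 4)), ("module.bias", torch.zeros(2))])

if all(key.startswith("module.") for key in legacy):
    legacy = collections.OrderedDict((key[len("module."):], value) for key, value in legacy.items())

nn.Linear(4, 2).load_state_dict(legacy)  # loads cleanly: keys are now plain "weight" / "bias"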