diff --git a/holocron/trainer/core.py b/holocron/trainer/core.py
index 9e6ce28e8..8b83257b5 100644
--- a/holocron/trainer/core.py
+++ b/holocron/trainer/core.py
@@ -5,12 +5,11 @@
 
 import math
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from contiguous_params import ContiguousParams
 from fastprogress import master_bar, progress_bar
 from fastprogress.fastprogress import ConsoleMasterBar
 from torch import Tensor, nn
@@ -58,7 +57,7 @@ def __init__(
         self.epoch = 0
         self.min_loss = math.inf
         self.gpu = gpu
-        self._params: Optional[List[ContiguousParams]] = None
+        self._params: Tuple[Sequence[torch.nn.Parameter], Sequence[torch.nn.Parameter]] = ([], [])
         self.lr_recorder: List[float] = []
         self.loss_recorder: List[float] = []
         self.set_device(gpu)
@@ -188,12 +187,9 @@ def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
             raise AssertionError("All parameters are frozen")
 
         if norm_weight_decay is None:
-            self._params = [ContiguousParams([p for p in self.model.parameters() if p.requires_grad])]
+            self._params = [p for p in self.model.parameters() if p.requires_grad], []
         else:
-            self._params = [
-                ContiguousParams(_params) if len(_params) > 0 else None
-                for _params in split_normalization_params(self.model)
-            ]
+            self._params = split_normalization_params(self.model)
 
     def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
         """Reset the target params of the optimizer"""
@@ -204,14 +200,14 @@ def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> No
         # Split it if norm layers needs custom WD
         if norm_weight_decay is None:
             self.optimizer.add_param_group(
-                dict(params=self._params[0].contiguous())  # type: ignore[index]
+                dict(params=self._params[0])  # type: ignore[index]
             )
         else:
            wd_groups = [norm_weight_decay, self.optimizer.defaults.get('weight_decay', 0)]
            for _params, _wd in zip(self._params, wd_groups):  # type: ignore[arg-type]
-                if _params:
+                if len(_params) > 0:
                     self.optimizer.add_param_group(
-                        dict(params=_params.contiguous(), weight_decay=_wd)
+                        dict(params=_params, weight_decay=_wd)
                     )
 
     @torch.inference_mode()
@@ -261,9 +257,6 @@ def fit_n_epochs(
         for _ in mb:
 
             self._fit_epoch(mb)
-            # Check whether ops invalidated the buffer
-            for _group in self._params:  # type: ignore[union-attr]
-                _group.assert_buffer_is_valid()
             eval_metrics = self.evaluate()
 
             # master bar
diff --git a/requirements.txt b/requirements.txt
index 85a155bd5..9c0c9c545 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,4 @@ tqdm>=4.1.0
 numpy>=1.17.2
 fastprogress>=1.0.0
 matplotlib>=3.0.0
-contiguous-params==1.0.0
 Pillow>=8.4.0
diff --git a/setup.py b/setup.py
index 21771e1dd..fbfd2bc93 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,6 @@
     "numpy>=1.17.2",
     "fastprogress>=1.0.0",
     "matplotlib>=3.0.0",
-    "contiguous-params==1.0.0",
     "Pillow>=8.4.0",  # cf. https://github.com/pytorch/vision/issues/4934
     # Testing
     "pytest>=5.3.2",
@@ -78,7 +77,6 @@ def deps_list(*pkgs):
     deps["numpy"],
     deps["fastprogress"],
     deps["matplotlib"],
-    deps["contiguous-params"],
     deps["Pillow"],
 ]
 
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 3eee2859a..93b57c694 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -76,6 +76,9 @@ def _test_trainer(
     lr: float = 1e-3
 ) -> None:
+    learner.model = trainer.utils.freeze_model(learner.model.train(), freeze_until)
+    learner._reset_opt(lr)
+
     # Update param groups & LR
     learner.save(learner.output_file)
     checkpoint = torch.load(learner.output_file, map_location='cpu')
     model_w = learner.model.state_dict()[ref_param].clone()