refactor: Removed contiguous params since torch>=1.7.0 includes it (#183)

* refactor: Removed contiguous params

* chore: Updated dependencies

* chore: Removed last deps

* test: Fixed trainer
frgfm authored Dec 25, 2021
1 parent 36e0383 commit 44f1168
Showing 4 changed files with 10 additions and 17 deletions.
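
For context on what the diffs below change, here is a minimal before/after sketch of the optimizer setup; the model and hyper-parameters are placeholders, and the commented-out block mirrors the contiguous_params calls removed from holocron/trainer/core.py rather than the trainer's full logic.

import torch
from torch import nn

# Placeholder model: any nn.Module with trainable parameters would do
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Before this commit (removed lines): gradients were packed into one contiguous buffer
#   from contiguous_params import ContiguousParams  # extra dependency, dropped here
#   params = ContiguousParams([p for p in model.parameters() if p.requires_grad])
#   optimizer = torch.optim.SGD(params.contiguous(), lr=1e-3)
#   params.assert_buffer_is_valid()  # had to be re-checked after every epoch

# After this commit: the plain parameter list goes straight to the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=1e-3)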
21 changes: 7 additions & 14 deletions holocron/trainer/core.py
@@ -5,12 +5,11 @@

 import math
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from contiguous_params import ContiguousParams
 from fastprogress import master_bar, progress_bar
 from fastprogress.fastprogress import ConsoleMasterBar
 from torch import Tensor, nn
@@ -58,7 +57,7 @@ def __init__(
         self.epoch = 0
         self.min_loss = math.inf
         self.gpu = gpu
-        self._params: Optional[List[ContiguousParams]] = None
+        self._params: Tuple[Sequence[torch.nn.Parameter], Sequence[torch.nn.Parameter]] = ([], [])
         self.lr_recorder: List[float] = []
         self.loss_recorder: List[float] = []
         self.set_device(gpu)
@@ -188,12 +187,9 @@ def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
             raise AssertionError("All parameters are frozen")

         if norm_weight_decay is None:
-            self._params = [ContiguousParams([p for p in self.model.parameters() if p.requires_grad])]
+            self._params = [p for p in self.model.parameters() if p.requires_grad], []
         else:
-            self._params = [
-                ContiguousParams(_params) if len(_params) > 0 else None
-                for _params in split_normalization_params(self.model)
-            ]
+            self._params = split_normalization_params(self.model)

     def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
         """Reset the target params of the optimizer"""
@@ -204,14 +200,14 @@ def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
         # Split it if norm layers needs custom WD
         if norm_weight_decay is None:
             self.optimizer.add_param_group(
-                dict(params=self._params[0].contiguous())  # type: ignore[index]
+                dict(params=self._params[0])  # type: ignore[index]
             )
         else:
             wd_groups = [norm_weight_decay, self.optimizer.defaults.get('weight_decay', 0)]
             for _params, _wd in zip(self._params, wd_groups):  # type: ignore[arg-type]
-                if _params:
+                if len(_params) > 0:
                     self.optimizer.add_param_group(
-                        dict(params=_params.contiguous(), weight_decay=_wd)
+                        dict(params=_params, weight_decay=_wd)
                     )
@@ -261,9 +257,6 @@ def fit_n_epochs(
         for _ in mb:

             self._fit_epoch(mb)
-            # Check whether ops invalidated the buffer
-            for _group in self._params:  # type: ignore[union-attr]
-                _group.assert_buffer_is_valid()
             eval_metrics = self.evaluate()

             # master bar
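
As a rough, standalone illustration of the pattern _set_params and _reset_opt now follow (plain parameter sequences plus per-group weight decay), consider the sketch below; split_norm_params is a simplified stand-in written for this example rather than holocron's split_normalization_params, and the model, learning rate and weight-decay values are placeholders.

from typing import List, Sequence, Tuple

import torch
from torch import nn

NORM_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm)

def split_norm_params(model: nn.Module) -> Tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
    """Toy stand-in: separate normalization-layer params from the others."""
    norm_params: List[nn.Parameter] = []
    other_params: List[nn.Parameter] = []
    for module in model.modules():
        if next(module.children(), None) is not None:
            continue  # only inspect leaf modules
        bucket = norm_params if isinstance(module, NORM_LAYERS) else other_params
        bucket.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
norm_weight_decay = 0.0   # e.g. no decay on normalization params
base_weight_decay = 1e-4  # decay for every other param

norm_params, other_params = split_norm_params(model)
optimizer = torch.optim.SGD(other_params, lr=1e-3, weight_decay=base_weight_decay)
if len(norm_params) > 0:
    # Same shape as the refactored _reset_opt: a plain sequence per param group
    optimizer.add_param_group(dict(params=norm_params, weight_decay=norm_weight_decay))

Since the optimizer only ever sees plain parameter lists here, the buffer-validity check removed from fit_n_epochs above has no equivalent in this setup.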
1 change: 0 additions & 1 deletion requirements.txt
@@ -4,5 +4,4 @@
 tqdm>=4.1.0
 numpy>=1.17.2
 fastprogress>=1.0.0
 matplotlib>=3.0.0
-contiguous-params==1.0.0
 Pillow>=8.4.0
2 changes: 0 additions & 2 deletions setup.py
@@ -46,7 +46,6 @@
     "numpy>=1.17.2",
     "fastprogress>=1.0.0",
     "matplotlib>=3.0.0",
-    "contiguous-params==1.0.0",
     "Pillow>=8.4.0",  # cf. https://github.com/pytorch/vision/issues/4934
     # Testing
     "pytest>=5.3.2",
@@ -78,7 +77,6 @@ def deps_list(*pkgs):
     deps["numpy"],
     deps["fastprogress"],
     deps["matplotlib"],
-    deps["contiguous-params"],
     deps["Pillow"],
 ]
3 changes: 3 additions & 0 deletions tests/test_trainer.py
@@ -76,6 +76,9 @@ def _test_trainer(
     lr: float = 1e-3
 ) -> None:

+    learner.model = trainer.utils.freeze_model(learner.model.train(), freeze_until)
+    learner._reset_opt(lr)
+    # Update param groups & LR
     learner.save(learner.output_file)
     checkpoint = torch.load(learner.output_file, map_location='cpu')
     model_w = learner.model.state_dict()[ref_param].clone()
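
The updated test freezes part of the model, rebuilds the optimizer via _reset_opt, then checks a checkpoint round trip against a reference weight; the snippet below is a stripped-down sketch of that save/reload comparison using a plain nn.Module, with a placeholder file name and a dummy weight update standing in for an actual training step.

import torch
from torch import nn

model = nn.Linear(4, 2)
ref_param = "weight"
output_file = "checkpoint.pth"  # placeholder path

# Save a checkpoint, then reload it on CPU as the test does
torch.save({"model": model.state_dict()}, output_file)
checkpoint = torch.load(output_file, map_location="cpu")

# Snapshot the reference weight, then fake a training update
model_w = model.state_dict()[ref_param].clone()
with torch.no_grad():
    model.weight.add_(1.0)  # stands in for real optimizer steps
trained_w = model.state_dict()[ref_param]

assert not torch.equal(model_w, trained_w)                    # the weight changed
assert torch.equal(checkpoint["model"][ref_param], model_w)   # the checkpoint kept the old value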
