refactor: Removed contiguous params since torch>=1.7.0 includes it #183

Merged · 4 commits · Dec 25, 2021
holocron/trainer/core.py — 7 additions & 14 deletions

@@ -5,12 +5,11 @@

 import math
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from contiguous_params import ContiguousParams
 from fastprogress import master_bar, progress_bar
 from fastprogress.fastprogress import ConsoleMasterBar
 from torch import Tensor, nn
@@ -58,7 +57,7 @@ def __init__(
         self.epoch = 0
         self.min_loss = math.inf
         self.gpu = gpu
-        self._params: Optional[List[ContiguousParams]] = None
+        self._params: Tuple[Sequence[torch.nn.Parameter], Sequence[torch.nn.Parameter]] = ([], [])
         self.lr_recorder: List[float] = []
         self.loss_recorder: List[float] = []
         self.set_device(gpu)
@@ -188,12 +187,9 @@ def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
             raise AssertionError("All parameters are frozen")

         if norm_weight_decay is None:
-            self._params = [ContiguousParams([p for p in self.model.parameters() if p.requires_grad])]
+            self._params = [p for p in self.model.parameters() if p.requires_grad], []
         else:
-            self._params = [
-                ContiguousParams(_params) if len(_params) > 0 else None
-                for _params in split_normalization_params(self.model)
-            ]
+            self._params = split_normalization_params(self.model)

     def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
         """Reset the target params of the optimizer"""
@@ -204,14 +200,14 @@ def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
         # Split it if norm layers needs custom WD
         if norm_weight_decay is None:
             self.optimizer.add_param_group(
-                dict(params=self._params[0].contiguous())  # type: ignore[index]
+                dict(params=self._params[0])  # type: ignore[index]
             )
         else:
             wd_groups = [norm_weight_decay, self.optimizer.defaults.get('weight_decay', 0)]
             for _params, _wd in zip(self._params, wd_groups):  # type: ignore[arg-type]
-                if _params:
+                if len(_params) > 0:
                     self.optimizer.add_param_group(
-                        dict(params=_params.contiguous(), weight_decay=_wd)
+                        dict(params=_params, weight_decay=_wd)
                     )

     @torch.inference_mode()
@@ -261,9 +257,6 @@ def fit_n_epochs(
         for _ in mb:

             self._fit_epoch(mb)
-            # Check whether ops invalidated the buffer
-            for _group in self._params:  # type: ignore[union-attr]
-                _group.assert_buffer_is_valid()
             eval_metrics = self.evaluate()

             # master bar
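With ContiguousParams gone, the trainer keeps plain parameter sequences and hands them directly to the optimizer. The snippet below is a minimal, self-contained sketch of that pattern; split_norm_params is a hypothetical stand-in for the split_normalization_params helper referenced in the diff, written here only to keep the example runnable.

import torch
from torch import nn

# Hypothetical stand-in for the split_normalization_params helper used in the diff:
# separate trainable normalization-layer parameters from the remaining trainable ones.
def split_norm_params(model: nn.Module):
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.GroupNorm)
    norm_params, other_params = [], []
    for module in model.modules():
        target = norm_params if isinstance(module, norm_types) else other_params
        target.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
    return norm_params, other_params

model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 10))
norm_params, other_params = split_norm_params(model)

# Plain parameter lists go straight into the optimizer: no ContiguousParams
# wrapper and no .contiguous() call on the groups.
optimizer = torch.optim.SGD(other_params, lr=1e-3, weight_decay=1e-4)
if len(norm_params) > 0:
    optimizer.add_param_group(dict(params=norm_params, weight_decay=0.0))

print([len(g["params"]) for g in optimizer.param_groups])  # [4, 2] for this toy model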
requirements.txt — 0 additions & 1 deletion

@@ -4,5 +4,4 @@ tqdm>=4.1.0
 numpy>=1.17.2
 fastprogress>=1.0.0
 matplotlib>=3.0.0
-contiguous-params==1.0.0
 Pillow>=8.4.0
setup.py — 0 additions & 2 deletions

@@ -46,7 +46,6 @@
     "numpy>=1.17.2",
     "fastprogress>=1.0.0",
     "matplotlib>=3.0.0",
-    "contiguous-params==1.0.0",
     "Pillow>=8.4.0",  # cf. https://github.com/pytorch/vision/issues/4934
     # Testing
     "pytest>=5.3.2",
@@ -78,7 +77,6 @@ def deps_list(*pkgs):
     deps["numpy"],
     deps["fastprogress"],
     deps["matplotlib"],
-    deps["contiguous-params"],
     deps["Pillow"],
 ]
tests/test_trainer.py — 3 additions & 0 deletions

@@ -76,6 +76,9 @@ def _test_trainer(
     lr: float = 1e-3
 ) -> None:

+    learner.model = trainer.utils.freeze_model(learner.model.train(), freeze_until)
+    learner._reset_opt(lr)
+    # Update param groups & LR
     learner.save(learner.output_file)
     checkpoint = torch.load(learner.output_file, map_location='cpu')
     model_w = learner.model.state_dict()[ref_param].clone()
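The added test lines freeze part of the model and rebuild the optimizer param groups before saving. Below is a rough sketch of the same check in plain PyTorch; holocron's freeze_model and _reset_opt are replaced by inline equivalents, so the names and details here are assumptions rather than the library's API.

import torch
from torch import nn

# Minimal sketch: freeze a layer, build the optimizer only from trainable
# parameters, then verify an update step leaves the frozen weights untouched.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

for p in model[0].parameters():  # freeze the first linear layer
    p.requires_grad_(False)

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3)

frozen_before = model[0].weight.clone()
loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()

assert torch.equal(frozen_before, model[0].weight)  # frozen layer unchanged
assert all(len(g["params"]) > 0 for g in optimizer.param_groups)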