train.py
import copy
import os.path
import typing
import warnings
from pprint import pprint
from typing import Any

import pytorch_lightning as pl
import torch
import torchmetrics as tm
from pytorch_lightning.callbacks import RichModelSummary
from pytorch_lightning.profilers import AdvancedProfiler, PyTorchProfiler
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torch._dynamo import config as dynamo_config
from torch.distributed import fsdp as fsdp_

from core import LightningSystem, data
from core.data.augmentation import StemAugmentor
from utils.config import dict_to_trainer_kwargs, read_nested_yaml

# Global numerics and compilation settings.
torch.set_float32_matmul_precision("medium")
dynamo_config.verbose = True
dynamo_config.cache_size_limit = 1024
torch.backends.cudnn.benchmark = True
# torch.autograd.set_detect_anomaly(True)
def train(
config_path: str,
ckpt_path: typing.Optional[str] = None,
adjust_loss_in_lieu_of_accumulate_grad_batches: bool = False,
strategy: typing.Optional[str] = "ddp",
just_validate: bool = False,
val_batch_size: typing.Optional[int] = None,
    finetune: typing.Union[bool, str] = False,
**kwargs: Any
) -> None:
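    """Train or validate a LightningSystem from a nested YAML config.

    Args:
        config_path: Path to the nested YAML configuration file.
        ckpt_path: Optional checkpoint to resume from, or to validate when
            ``just_validate`` is set.
        adjust_loss_in_lieu_of_accumulate_grad_batches: If True, scale the loss
            by 1 / accumulate_grad_batches instead of using Lightning's
            gradient accumulation.
        strategy: Distributed strategy; "ddp" is replaced by a tuned
            DDPStrategy on multi-GPU machines, and single-GPU runs use "auto".
        just_validate: Run validation only, without fitting.
        val_batch_size: Batch-size override, allowed only with ``just_validate``.
        finetune: Finetuning mode; "dnr->musdb" or "dnr->mne" loads the
            checkpoint weights non-strictly and then starts a fresh run.
        **kwargs: Forwarded to ``pl.Trainer``.
    """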
config = read_nested_yaml(config_path)
config_ = copy.deepcopy(config)
pprint(config)
pl.seed_everything(seed=config["seed"], workers=True)
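    # Validation-only runs are logged under a separate logger name.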
if just_validate:
config["trainer"]["logger"]["kwargs"]["name"] = "validation"
assert isinstance(config["data"], dict)
assert isinstance(config["data"]["data"], dict)
dmcls = config["data"]["data"].pop("datamodule")
assert isinstance(dmcls, str)
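    # Optionally override the batch size for validation-only runs, then build
    # the datamodule named in the config.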
if val_batch_size is not None:
assert just_validate
config["data"]["data"]["batch_size"] = val_batch_size
datamodule = data.__dict__[dmcls](**config["data"]["data"])
assert isinstance(config["trainer"], dict)
trainer_kwargs = dict_to_trainer_kwargs(config["trainer"])
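    # Resolve gradient accumulation: when the config specifies an effective
    # batch size, derive accumulate_grad_batches from the per-GPU batch size
    # and the GPU count (or fold it into a loss-scaling factor instead).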
loss_adjustment = 1.0
if "effective_batch_size" in trainer_kwargs:
assert ("accumulate_grad_batches" not in trainer_kwargs) or (
trainer_kwargs["accumulate_grad_batches"] is None
)
val_batch_size = config["data"]["data"]["batch_size"]
gpu_count = torch.cuda.device_count()
assert isinstance(val_batch_size, int)
effective_batch_size = trainer_kwargs.pop("effective_batch_size")
assert isinstance(effective_batch_size, int)
assert effective_batch_size % (gpu_count * val_batch_size) == 0
accumulate_grad_batches = effective_batch_size // (
gpu_count * val_batch_size)
if adjust_loss_in_lieu_of_accumulate_grad_batches:
loss_adjustment = 1.0 / accumulate_grad_batches
use_static_graph = True
trainer_kwargs["accumulate_grad_batches"] = 1
else:
trainer_kwargs["accumulate_grad_batches"] = accumulate_grad_batches
print(
f"Batch size: {val_batch_size}. Requesting effective batch size: {effective_batch_size}."
)
print(
f"Accumulating gradients from {accumulate_grad_batches} batches."
)
use_static_graph = False
else:
use_static_graph = "accumulate_grad_batches" not in trainer_kwargs
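    # Choose the distributed strategy: a single GPU falls back to Lightning's
    # default, while multi-GPU DDP uses bucket-view gradients and a static
    # graph whenever Lightning-side gradient accumulation is not in use.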
if torch.cuda.device_count() == 1:
strategy = "auto"
else:
if strategy == "ddp":
strategy = DDPStrategy(
static_graph=use_static_graph, gradient_as_bucket_view=True
)
    trainer_kwargs.setdefault("callbacks", []).append(RichModelSummary(max_depth=3))
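    # Build the trainer; the commented-out profiler arguments can be
    # re-enabled for performance debugging.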
trainer = pl.Trainer(
**trainer_kwargs, # type: ignore[arg-type]
strategy=strategy,
# profiler=AdvancedProfiler(filename="profiler.txt"),
**kwargs,
# profiler=PyTorchProfiler(with_modules=True, filename="profiler.txt")
)
assert isinstance(config["system"], dict)
if "augmentation" in config["system"]:
pass
elif "augmentation" in config["data"]:
warnings.warn(
"Augmentation should now be put under system.augmentation "
"instead of data.augmentation.",
DeprecationWarning,
)
config["system"]["augmentation"] = config["data"]["augmentation"]
else:
config["system"]["augmentation"] = None
model = LightningSystem(config["system"], loss_adjustment=loss_adjustment)
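    # Cross-dataset finetuning: load the pretrained weights non-strictly and
    # drop ckpt_path so the trainer starts a fresh fit from those weights.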
if finetune in ["dnr->musdb", "dnr->mne"]:
assert ckpt_path is not None
ckpt = torch.load(ckpt_path, map_location='cpu')
state_dict = ckpt['state_dict']
missing, unexpected = model.load_state_dict(
state_dict,
strict=False
)
missing = set([m.split(".")[2] for m in missing])
unexpected = set([u.split(".")[2] for u in unexpected])
print(f"Missing keys: {missing}")
print(f"Unexpected keys: {unexpected}")
ckpt_path = None
# raise NotImplementedError
del ckpt
assert trainer.logger is not None
trainer.logger.log_hyperparams(config_)
trainer.logger.save()
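    # Either validate an existing checkpoint (defaulting to last.ckpt next to
    # the config) or fit, then test the best checkpoint with the fader attached.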
if just_validate:
if ckpt_path is None:
ckpt_path = os.path.join(os.path.dirname(config_path), "checkpoints", "last.ckpt")
assert os.path.exists(ckpt_path)
trainer.validate(model, datamodule=datamodule, ckpt_path=ckpt_path)
else:
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
model.attach_fader()
trainer.test(model, datamodule=datamodule, ckpt_path="best")
if __name__ == "__main__":
import fire
fire.Fire(train)
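# Example invocations via python-fire, which maps the keyword arguments of
# train() to CLI flags (paths below are hypothetical):
#   python train.py --config_path=configs/example.yaml
#   python train.py --config_path=logs/run1/config.yaml --just_validate=True --val_batch_size=4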