-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest.py
95 lines (70 loc) · 2.56 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Imports: stdlib / third-party / project-local, one group per block.
# Fix: the original imported `typing` twice; deduplicated here.
import copy
import os.path
import typing
import warnings
from pprint import pprint
from typing import Any

import pytorch_lightning as pl
import torch
import torchmetrics as tm
from pytorch_lightning.callbacks import RichModelSummary
from pytorch_lightning.profilers import AdvancedProfiler, PyTorchProfiler
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torch._dynamo import config
from torch.distributed import fsdp as fsdp_

from core import LightningSystem, data
from core.data.augmentation import StemAugmentor
from utils.config import dict_to_trainer_kwargs, read_nested_yaml

# Global runtime tuning, applied once at import time (module-level side effects
# preserved from the original file).
torch.set_float32_matmul_precision("medium")  # trade precision for matmul speed
config.verbose = True  # verbose torch._dynamo logging
config.cache_size_limit = 1024  # raise dynamo recompilation cache limit
torch.backends.cudnn.benchmark = True  # autotune cuDNN kernels for fixed shapes
def test(
    config_path: str,
    ckpt_path: typing.Optional[str] = None,
    **kwargs: Any
) -> None:
    """Run a PyTorch Lightning test pass for the experiment in a YAML config.

    Args:
        config_path: Path to a nested YAML config with ``seed``, ``data``,
            ``trainer`` and ``system`` sections (parsed by
            ``read_nested_yaml``).
        ckpt_path: Optional checkpoint path, forwarded to ``Trainer.test``.
        **kwargs: Extra keyword arguments forwarded to ``pl.Trainer``.

    Raises:
        AssertionError: If the config sections do not have the expected
            dict/str structure.
        KeyError: If a required config key (e.g. ``seed``, ``datamodule``)
            is missing.
    """
    config = read_nested_yaml(config_path)
    pprint(config)

    # Seed everything (including dataloader workers) for reproducibility.
    pl.seed_everything(seed=config["seed"], workers=True)

    assert isinstance(config["data"], dict)
    assert isinstance(config["data"]["data"], dict)
    dmcls = config["data"]["data"].pop("datamodule")
    assert isinstance(dmcls, str)
    # Resolve the datamodule class by name from the `data` package namespace.
    datamodule = data.__dict__[dmcls](**config["data"]["data"])

    assert isinstance(config["trainer"], dict)
    trainer_kwargs = dict_to_trainer_kwargs(config["trainer"])
    # setdefault guards against configs that define no callbacks at all
    # (the original indexed ['callbacks'] directly and could KeyError).
    trainer_kwargs.setdefault("callbacks", []).append(
        RichModelSummary(max_depth=3)
    )

    # No gradient accumulation / multi-device scaling during testing.
    loss_adjustment = 1.0
    trainer = pl.Trainer(
        **trainer_kwargs,  # type: ignore[arg-type]
        strategy="auto",
        **kwargs,
    )

    assert isinstance(config["system"], dict)
    if "augmentation" in config["system"]:
        pass  # already in the preferred location
    elif "augmentation" in config["data"]:
        # Legacy location: migrate data.augmentation -> system.augmentation.
        warnings.warn(
            "Augmentation should now be put under system.augmentation "
            "instead of data.augmentation.",
            DeprecationWarning,
        )
        config["system"]["augmentation"] = config["data"]["augmentation"]
    else:
        config["system"]["augmentation"] = None

    model = LightningSystem(config["system"], loss_adjustment=loss_adjustment)
    # NOTE(review): force a clean fader re-attachment before testing; the
    # scraped source lost indentation, assuming both statements sit inside
    # the `if` as in the original — confirm against upstream.
    if model.fader is not None:
        model.fader = None
        model.attach_fader(force_reattach=True)

    trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
if __name__ == "__main__":
    # CLI entry point: python-fire maps command-line args onto `test`'s
    # parameters (config_path, ckpt_path, plus arbitrary Trainer kwargs).
    import fire
    fire.Fire(test)