-
Notifications
You must be signed in to change notification settings - Fork 321
/
test_vision_models.py
56 lines (45 loc) · 1.27 KB
/
test_vision_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule
from pl_bolts.models import GPT2, ImageGPT, UNet
def _fast_trainer(tmpdir):
    """Build a Trainer limited to 2 train/val/test batches and a single epoch.

    ``default_root_dir`` is pointed at the pytest ``tmpdir`` so that logs and
    checkpoints are written into the per-test temporary directory rather than
    the current working directory.
    """
    return pl.Trainer(
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
    )


def test_igpt(tmpdir):
    """Smoke-test ImageGPT on MNIST (default mode) and FashionMNIST (``classify=True``).

    The first run is fit and tested, and its ``test_loss`` is checked against a
    fixed threshold. The second run is only fit; no metric is asserted for it.
    """
    # Default (non-classify) run on MNIST.
    pl.seed_everything(0)
    # NOTE(review): presumably the datamodule stores/downloads data under tmpdir — confirm.
    dm = MNISTDataModule(tmpdir, normalize=False)
    model = ImageGPT(datamodule=dm)
    trainer = _fast_trainer(tmpdir)
    trainer.fit(model)
    trainer.test()
    # Regression threshold; the origin of 1.7 is not documented here.
    assert trainer.callback_metrics["test_loss"] < 1.7

    # Classification run on FashionMNIST. Reseed so this half does not depend
    # on how much randomness the first half consumed.
    pl.seed_everything(0)
    dm = FashionMNISTDataModule(tmpdir, num_workers=1)
    model = ImageGPT(classify=True, datamodule=dm)
    trainer = _fast_trainer(tmpdir)
    trainer.fit(model)
def test_gpt2(tmpdir):
    """Run one forward pass of a tiny GPT2 on random integer token ids.

    Only checks that the call completes; the output is not inspected.
    ``tmpdir`` is an unused pytest fixture kept for signature compatibility.
    """
    vocab_size = 16
    batch_size = 32
    seq_len = 17

    # Random ids in [0, vocab_size) laid out as (seq_len, batch_size).
    tokens = torch.randint(0, vocab_size, (seq_len, batch_size))

    gpt2_kwargs = dict(
        embed_dim=16,
        heads=2,
        layers=2,
        num_positions=seq_len,
        vocab_size=vocab_size,
        num_classes=10,
    )
    model = GPT2(**gpt2_kwargs)
    model(tokens)
def test_unet(tmpdir):
    """Check that UNet(num_classes=2) maps a (10, 3, 28, 28) batch to (10, 2, 28, 28).

    ``tmpdir`` is an unused pytest fixture kept for signature compatibility.
    """
    batch = torch.rand(10, 3, 28, 28)
    net = UNet(num_classes=2)
    expected_shape = torch.Size([10, 2, 28, 28])
    output = net(batch)
    assert output.shape == expected_shape