Norm.py
import torch
from typing import List, Callable
from torch_geometric.nn.norm import GraphNorm as PygGN, InstanceNorm as PygIN
from torch import Tensor
import torch.nn as nn


def expandbatch(x: Tensor, batch: Tensor):
    """Flatten a stacked node-feature tensor of shape (R, N, F) to (R*N, F).

    If a batch vector is given, it is replicated R times with per-replica
    offsets so that the R copies are treated as disjoint graphs.
    """
    if batch is None:
        return x.flatten(0, 1), None
    else:
        R = x.shape[0]
        N = batch[-1] + 1  # number of graphs per replica (batch is assumed sorted)
        offset = N * torch.arange(R, device=x.device).reshape(-1, 1)
        batch = batch.unsqueeze(0) + offset
        return x.flatten(0, 1), batch.flatten()
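
# A concrete illustration of expandbatch (values taken from the demo in
# __main__ below): x of shape (3, 4, 5) with batch = [0, 0, 1, 2] is flattened
# to shape (12, 5), and the returned batch vector becomes
# [0, 0, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], i.e. each of the 3 replicas gets its
# own block of graph indices.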

class NormMomentumScheduler:
    """Rescales the momentum of all normalization layers of a given type.

    The new momentum is initmomentum * mfunc(epoch), where mfunc is a
    user-supplied schedule over the internal epoch counter.
    """

    def __init__(self, mfunc: Callable, initmomentum: float, normtype=nn.BatchNorm1d) -> None:
        super().__init__()
        self.normtype = normtype
        self.mfunc = mfunc
        self.epoch = 0
        self.initmomentum = initmomentum

    def step(self, model: nn.Module):
        ratio = self.mfunc(self.epoch)
        self.epoch += 1  # advance the counter even when the momentum is left unchanged
        if 1 - 1e-6 < ratio < 1 + 1e-6:
            # Ratio is effectively 1: keep the initial momentum and skip the module scan.
            return self.initmomentum
        curm = self.initmomentum * ratio
        for mod in model.modules():
            if type(mod) is self.normtype:
                mod.momentum = curm
        return curm
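
# A minimal usage sketch for NormMomentumScheduler, assuming a model that
# contains nn.BatchNorm1d layers; the schedule and the training helper below
# are hypothetical placeholders, not part of this module:
#
#   scheduler = NormMomentumScheduler(lambda epoch: 0.9 ** epoch, initmomentum=0.1)
#   for epoch in range(num_epochs):
#       train_one_epoch(model)   # hypothetical training routine
#       scheduler.step(model)    # rescales the momentum of every nn.BatchNorm1d in model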

class NoneNorm(nn.Module):
    """Identity placeholder used when no normalization is requested."""

    def __init__(self, dim=0, normparam=0) -> None:
        super().__init__()
        self.num_features = dim

    def forward(self, x):
        return x


class BatchNorm(nn.Module):
    """nn.BatchNorm1d wrapper that also accepts inputs with more than two dimensions."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.BatchNorm1d(dim, momentum=normparam)

    def forward(self, x: Tensor):
        if x.dim() == 2:
            return self.norm(x)
        elif x.dim() > 2:
            # Flatten all leading dimensions, normalize over the feature
            # dimension, then restore the original shape.
            shape = x.shape
            x = self.norm(x.flatten(0, -2)).reshape(shape)
            return x
        else:
            raise NotImplementedError

class LayerNorm(nn.Module):
    """Thin wrapper around nn.LayerNorm; normparam is accepted only for interface parity."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor):
        return self.norm(x)


class InstanceNorm(nn.Module):
    """Wrapper around PyG's InstanceNorm that also accepts inputs with more than two dimensions."""

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.norm = PygIN(dim, momentum=normparam)
        self.num_features = dim

    def forward(self, x: Tensor):
        if x.dim() == 2:
            return self.norm(x)
        elif x.dim() > 2:
            # Flatten all leading dimensions before normalizing, then restore the shape.
            shape = x.shape
            x = self.norm(x.flatten(0, -2)).reshape(shape)
            return x
        else:
            raise NotImplementedError

normdict = {"bn": BatchNorm, "ln": LayerNorm, "in": InstanceNorm, "none": NoneNorm}
basenormdict = {"bn": nn.BatchNorm1d, "ln": None, "in": PygIN, "gn": None, "none": None}
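
# A minimal sketch of picking a norm layer by key, assuming 64-dimensional
# features; all wrappers share the (dim, normparam) constructor signature:
#
#   norm = normdict["bn"](64, normparam=0.1)
#   out = norm(torch.randn(32, 64))
#
# basenormdict maps the same keys to the underlying, unwrapped norm classes
# (or None where no direct equivalent is wrapped here).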

if __name__ == "__main__":
    # Batched input: 3 replicas of 4 nodes with 5 features, 3 graphs per replica.
    x = torch.randn((3, 4, 5))
    batch = torch.tensor((0, 0, 1, 2))
    x, batch = expandbatch(x, batch)
    print(x.shape, batch)
    # Without a batch vector only the flattening is performed.
    x = torch.randn((3, 4, 5))
    batch = None
    x, batch = expandbatch(x, batch)
    print(x.shape, batch)
    print(list(InstanceNorm(1000).modules()))