-
Notifications
You must be signed in to change notification settings - Fork 9
/
Net.py
87 lines (68 loc) · 2.02 KB
/
Net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import abc
from collections import OrderedDict
from torch import cos, sin
class MLP(torch.nn.Module):
    """Fully connected network with tanh activations between layers.

    The network is a chain of ``torch.nn.Linear`` layers. ``tanh`` is applied
    after every layer except the last one, so the output is an unbounded
    linear combination.

    Args:
        seq: Sequence of layer widths. Layer ``i`` maps ``seq[i]`` features to
            ``seq[i + 1]`` features, giving ``len(seq) - 1`` layers in total.
        name: Prefix for the layer keys; layer ``i`` is registered in
            ``self.layers`` under ``'<name>_<i>'``.
    """

    def __init__(self, seq, name='mlp'):
        super().__init__()
        # Build the mapping locally and register it once, instead of first
        # assigning a plain dict to the module attribute and then replacing it.
        layers = OrderedDict()
        for i, (n_in, n_out) in enumerate(zip(seq[:-1], seq[1:])):
            layers[f'{name}_{i}'] = torch.nn.Linear(n_in, n_out)
        self.layers = torch.nn.ModuleDict(layers)

    def forward(self, x):
        """Run ``x`` through all layers; tanh follows every layer but the last."""
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers.values()):
            x = layer(x)
            if idx < last:
                x = torch.tanh(x)
        return x
class MLP_sin(MLP):
    """MLP variant that uses ``sin`` instead of ``tanh`` between layers."""

    def forward(self, x):
        """Run ``x`` through all layers; sin follows every layer but the last.

        The previous implementation kept a counter that was never incremented
        and compared it against ``len(self.layers) - 2``. For a two-layer
        network that skipped the output layer entirely; for any other depth
        it applied ``sin`` to the final output as well. The loop now treats
        the last layer as a plain linear output, the same way ``MLP.forward``
        does with tanh.
        """
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers.values()):
            x = layer(x)
            if idx < last:
                x = torch.sin(x)
        return x
class Net(torch.nn.Module):
    """Thin module wrapper that owns a single ``MLP`` and delegates to it.

    Args:
        seq: Layer widths, passed unchanged to ``MLP``.
    """

    def __init__(self, seq):
        super().__init__()
        self.mlp = MLP(seq)

    def forward(self, x):
        """Return the wrapped MLP's output for ``x``."""
        output = self.mlp(x)
        return output
class Basis_Net(torch.nn.Module):
    """MLP whose outputs are used as weights on a set of basis values.

    Args:
        seq: Layer widths for the internal ``MLP``; the last entry is also
            stored as ``basis_num``.
        basis: Callable applied to the network input. Its result is multiplied
            element-wise with the MLP output, so it presumably returns one
            column per MLP output -- confirm against the caller.
    """

    def __init__(self, seq, basis):
        super().__init__()
        self.basis_num = seq[-1]
        self.mlp = MLP(seq)
        self.basis = basis

    def forward(self, x):
        """Return ``sum(mlp(x) * basis(x))`` over dim 1 as a column tensor."""
        # The basis is evaluated on the raw input, before the MLP runs.
        basis_values = self.basis(x)
        coefficients = self.mlp(x)
        weighted = coefficients * basis_values
        return torch.sum(weighted, dim=1).reshape(-1, 1)
class Basis_Net_Time(Basis_Net):
    """``Basis_Net`` variant whose basis sees only the first input column.

    The MLP still receives the full input. Judging by the class name the
    first column is presumably the time coordinate -- confirm with callers.
    """

    def forward(self, x):
        """Return ``sum(mlp(x) * basis(x[:, :1]))`` over dim 1 as a column."""
        first_column = x[:, :1]
        basis_values = self.basis(first_column)
        coefficients = self.mlp(x)
        weighted = coefficients * basis_values
        return torch.sum(weighted, dim=1).reshape(-1, 1)
class Sphere_Net(Net):
    """``Net`` that maps its input through ``coordinates_get_3d`` first."""

    def forward(self, x):
        """Convert ``x`` with ``coordinates_get_3d`` and feed it to the MLP."""
        return self.mlp(coordinates_get_3d(x))
class SPH_Sphere_Net(Basis_Net):
    """``Basis_Net`` variant whose MLP sees ``coordinates_get_3d(x)``.

    The basis is evaluated on the raw input ``x``; only the MLP receives the
    converted coordinates.
    """

    def forward(self, x):
        """Return ``sum(mlp(coords(x)) * basis(x))`` over dim 1 as a column."""
        basis_values = self.basis(x)
        coefficients = self.mlp(coordinates_get_3d(x))
        weighted = coefficients * basis_values
        return weighted.sum(dim=1).reshape(-1, 1)
def coordinates_get_3d(x):
    """Map angle columns to 3-D Cartesian components.

    With ``a = x[:, :1]`` and ``b = x[:, 1:]`` the result is the column-wise
    concatenation of ``sin(a) * sin(b)``, ``sin(a) * cos(b)`` and ``cos(a)``.
    For a two-column input this yields three columns with unit Euclidean
    norm per row.

    Args:
        x: 2-D tensor; the first column is one angle and the remaining
            column(s) the other.

    Returns:
        Tensor holding the three components concatenated along dim 1.
    """
    first_angle, second_angle = x[:, :1], x[:, 1:]
    sin_first = torch.sin(first_angle)
    components = (
        sin_first * torch.sin(second_angle),
        sin_first * torch.cos(second_angle),
        torch.cos(first_angle),
    )
    return torch.cat(components, dim=1)