-
Notifications
You must be signed in to change notification settings - Fork 13
/
model.py
57 lines (49 loc) · 1.86 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from unet import UNet
def warp(img, flow):
    """Backward-warp ``img`` with a dense per-pixel displacement field.

    Output pixel (row i, col j) is bilinearly sampled from ``img`` at pixel
    coordinates (x = j + flow[:, 0, i, j], y = i + flow[:, 1, i, j]): channel 0
    of ``flow`` is the horizontal displacement, channel 1 the vertical one,
    both in pixels. Samples that fall outside the image are zero-padded
    (``grid_sample``'s default padding mode).

    Args:
        img: tensor of shape (N, C, H, W).
        flow: tensor of shape (N, >=2, H, W); only channels 0 and 1 are read.
            Must have the same dtype as ``img`` (``grid_sample`` requirement).

    Returns:
        Tensor of shape (N, C, H, W). A zero ``flow`` returns ``img`` unchanged.
    """
    _, _, H, W = img.size()
    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]
    # Build the base pixel grid on flow's own device/dtype instead of forcing
    # ``.cuda()``, so warping works on CPU and on whichever GPU holds the data.
    base_x = torch.arange(W, device=flow.device, dtype=flow.dtype).view(1, 1, W)
    base_y = torch.arange(H, device=flow.device, dtype=flow.dtype).view(1, H, 1)
    x = base_x + u  # broadcasts to (N, H, W)
    y = base_y + v
    # Map pixel-centre coordinates to grid_sample's [-1, 1] range using the
    # align_corners=False convention, where pixel k's centre is (2k + 1)/size - 1.
    # The former 2*(x/W - 0.5) mapping was half a pixel off, so even a zero
    # flow shifted and blurred the image.
    norm_x = (2 * x + 1) / W - 1
    norm_y = (2 * y + 1) / H - 1
    grid = torch.stack((norm_x, norm_y), dim=3)  # (N, H, W, 2), (x, y) order
    return F.grid_sample(img, grid, align_corners=False)
class Net(nn.Module):
    """Network that synthesizes an intermediate frame at time ``t``.

    Pipeline:
      1. ``Flow_L`` predicts the bidirectional flows F_0->1 and F_1->0 from the
         two concatenated input frames.
      2. These are combined into approximate intermediate flows F_t->0 and
         F_t->1 (the quadratic-in-t approximation used by Super SloMo), and
         ``refine_flow`` predicts residual corrections to them.
      3. Both inputs are backward-warped to time ``t`` and blended with
         per-pixel weights derived from ``Mask``.
      4. ``final`` predicts a residual on top of the blend; the sum is
         clamped to [0, 1].

    NOTE(review): ``UNet(a, b, c)`` is presumably (in_channels, out_channels,
    depth); the channel counts below are consistent with that reading, which
    also implies 3-channel input frames. Confirm against ``unet.py``.
    """

    def __init__(self, level=3):
        # ``level`` is accepted for backward compatibility but is not used.
        super(Net, self).__init__()
        # flows t->0, t->1 (2+2) + both frames (6) + both warped frames (6)
        # -> 2 maps, turned into blending weights in ``process``.
        self.Mask = UNet(16, 2, 4)
        # both frames (6) -> F_0->1 (2) + F_1->0 (2)
        self.Flow_L = UNet(6, 4, 5)
        # approximate flows t->0, t->1 (2+2) + both frames (6) -> flow residuals (4)
        self.refine_flow = UNet(10, 4, 4)
        # both frames (6) + blended output (3) -> image residual (3)
        self.final = UNet(9, 3, 4)

    def process(self, x0, x1, t):
        """Return the visibility-weighted blend of both inputs warped to ``t``.

        Args:
            x0, x1: input frames, (N, C, H, W); concatenated along dim 1.
            t: interpolation time, presumably in [0, 1]. At t=0 only the
                warped ``x0`` contributes to the blend, at t=1 only the warped
                ``x1``. May be a Python number, a 0-d tensor, or a 1-D tensor
                of shape (N,) giving one time per sample.

        Returns:
            The blended frame, shaped like the warped inputs.
        """
        # A per-sample (N,) time vector must broadcast over (N, C, H, W);
        # left as-is it would broadcast against the last (W) axis instead.
        if torch.is_tensor(t) and t.dim() == 1:
            t = t.view(-1, 1, 1, 1)
        frames = torch.cat((x0, x1), 1)
        flows = self.Flow_L(frames)
        flow_0_1, flow_1_0 = flows[:, :2, :, :], flows[:, 2:4, :, :]
        # Approximate the flows from time t back to each input frame.
        flow_t_0 = -(1 - t) * t * flow_0_1 + t * t * flow_1_0
        flow_t_1 = (1 - t) * (1 - t) * flow_0_1 - t * (1 - t) * flow_1_0
        # Learnt residual correction of the approximate flows.
        residual = self.refine_flow(torch.cat((flow_t_0, flow_t_1, frames), 1))
        flow_t_0 = flow_t_0 + residual[:, :2, :, :]
        flow_t_1 = flow_t_1 + residual[:, 2:4, :, :]
        xt1 = warp(x0, flow_t_0)
        xt2 = warp(x1, flow_t_1)
        mask = torch.sigmoid(
            self.Mask(torch.cat((flow_t_0, flow_t_1, frames, xt1, xt2), 1))
        )
        # Weight each warped frame by temporal proximity times its predicted
        # mask; the epsilon avoids division by zero when both weights vanish.
        w1 = (1 - t) * mask[:, 0:1, :, :]
        w2 = t * mask[:, 1:2, :, :]
        return (w1 * xt1 + w2 * xt2) / (w1 + w2 + 1e-8)

    def forward(self, input0, input1, t=0.5):
        """Interpolate the frame at time ``t`` between ``input0`` and ``input1``.

        Adds the residual predicted by ``final`` to the blend from ``process``
        and clamps the result to [0, 1]. ``t`` accepts the same forms as in
        ``process``.
        """
        blended = self.process(input0, input1, t)
        compose = torch.cat((input0, input1, blended), 1)
        refined = self.final(compose) + blended
        return refined.clamp(0, 1)