util.py
from __future__ import division
import torch
from torch.utils.serialization import load_lua  # requires PyTorch <= 0.4.x; removed in 1.0+
import torchvision.transforms as transforms
import numpy as np
import argparse
import time
import os
from PIL import Image
from modelsNIPS import decoder1, decoder2, decoder3, decoder4, decoder5
from modelsNIPS import encoder1, encoder2, encoder3, encoder4, encoder5
import torch.nn as nn

class WCT(nn.Module):
    def __init__(self, args):
        super(WCT, self).__init__()
        # Load the pre-trained VGG encoders and their matching decoders
        # (Torch7 .t7 checkpoints, one pair per relu layer 1-5).
        vgg1 = load_lua(args.vgg1)
        decoder1_torch = load_lua(args.decoder1)
        vgg2 = load_lua(args.vgg2)
        decoder2_torch = load_lua(args.decoder2)
        vgg3 = load_lua(args.vgg3)
        decoder3_torch = load_lua(args.decoder3)
        vgg4 = load_lua(args.vgg4)
        decoder4_torch = load_lua(args.decoder4)
        vgg5 = load_lua(args.vgg5)
        decoder5_torch = load_lua(args.decoder5)
        self.e1 = encoder1(vgg1)
        self.d1 = decoder1(decoder1_torch)
        self.e2 = encoder2(vgg2)
        self.d2 = decoder2(decoder2_torch)
        self.e3 = encoder3(vgg3)
        self.d3 = decoder3(decoder3_torch)
        self.e4 = encoder4(vgg4)
        self.d4 = decoder4(decoder4_torch)
        self.e5 = encoder5(vgg5)
        self.d5 = decoder5(decoder5_torch)
    def whiten_and_color(self, cF, sF):
        # cF, sF: flattened content/style features of shape (C, H*W).
        cFSize = cF.size()
        c_mean = torch.mean(cF, 1)  # per-channel mean over the C x (H*W) map
        c_mean = c_mean.unsqueeze(1).expand_as(cF)
        cF = cF - c_mean

        # Content covariance; the identity term keeps the SVD numerically stable.
        contentConv = torch.mm(cF, cF.t()).div(cFSize[1] - 1) + torch.eye(cFSize[0]).double()
        c_u, c_e, c_v = torch.svd(contentConv, some=False)

        # Keep only eigen-directions with non-negligible eigenvalues.
        k_c = cFSize[0]
        for i in range(cFSize[0]):
            if c_e[i] < 0.00001:
                k_c = i
                break

        sFSize = sF.size()
        s_mean = torch.mean(sF, 1)
        sF = sF - s_mean.unsqueeze(1).expand_as(sF)
        styleConv = torch.mm(sF, sF.t()).div(sFSize[1] - 1)
        s_u, s_e, s_v = torch.svd(styleConv, some=False)
        k_s = sFSize[0]
        for i in range(sFSize[0]):
            if s_e[i] < 0.00001:
                k_s = i
                break

        # Whitening: E diag(d^-1/2) E^T cF decorrelates the content channels.
        c_d = (c_e[0:k_c]).pow(-0.5)
        step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
        step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
        whiten_cF = torch.mm(step2, cF)

        # Coloring: re-project onto the style eigenbasis scaled by diag(d^1/2),
        # then restore the style's per-channel means.
        s_d = (s_e[0:k_s]).pow(0.5)
        targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
        targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
        return targetFeature
    def transform(self, cF, sF, csF, alpha):
        # Work in double precision for a more stable SVD.
        cF = cF.double()
        sF = sF.double()
        C, W, H = cF.size(0), cF.size(1), cF.size(2)
        _, W1, H1 = sF.size(0), sF.size(1), sF.size(2)  # style spatial dims (unused)
        cFView = cF.view(C, -1)
        sFView = sF.view(C, -1)
        targetFeature = self.whiten_and_color(cFView, sFView)
        targetFeature = targetFeature.view_as(cF)
        # Blend the stylized feature with the original content feature.
        ccsF = alpha * targetFeature + (1.0 - alpha) * cF
        ccsF = ccsF.float().unsqueeze(0)
        # Write the result into the caller-provided buffer in place.
        csF.data.resize_(ccsF.size()).copy_(ccsF)
        return csF
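
# --- Usage sketch (an assumption for illustration; not part of the original file) ---
# A minimal way to exercise whiten_and_color on random features, without the
# pre-trained .t7 checkpoints. Since the method never reads self, on Python 3 we
# can call it through the class with None as self. The check below relies on the
# coloring step re-centering the output on the style's per-channel means.
if __name__ == '__main__':
    C, N = 8, 1024  # channels, spatial positions (H*W)
    cF = torch.randn(C, N).double()
    sF = torch.randn(C, N).double()
    out = WCT.whiten_and_color(None, cF, sF)
    print(out.size())  # (C, N), same shape as the content feature map
    # Output channel means should match the style channel means (up to float error).
    print((out.mean(1) - sF.mean(1)).abs().max())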