forked from nkolkin13/NeuralNeighborStyleTransfer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
103 lines (90 loc) · 3.17 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import tempfile
import time
from imageio import imwrite
import torch
import numpy as np
from cog import BasePredictor, Path, Input
# Internal Project Imports
from pretrained.vgg import Vgg16Pretrained
from utils import misc as misc
from utils.misc import load_path_for_pytorch
from utils.stylize import produce_stylization
class Predictor(BasePredictor):
    """Cog predictor for Neural Neighbor Style Transfer.

    ``setup`` builds the pretrained VGG16 feature extractor once per process;
    ``predict`` stylizes a content image with a style image and returns the
    path of the resulting PNG.
    """

    def setup(self):
        """Load the pretrained VGG16 feature extractor onto the target device."""
        # NOTE(review): placement is decided by misc.to_device, presumably from
        # the module-level misc.USE_GPU flag. predict() forces USE_GPU = True,
        # so confirm the module default agrees or the model and images may end
        # up on different devices.
        cnn = misc.to_device(Vgg16Pretrained())
        # Positional (x, inds, concat) adapter around cnn.forward; it is handed
        # to produce_stylization as its feature extractor.
        self.phi = lambda x, y, z: cnn.forward(x, inds=y, concat=z)

    @staticmethod
    def _load_image(path, size):
        """Load the image at ``path`` resized to ``size``, on device, with a batch dim."""
        return misc.to_device(
            load_path_for_pytorch(str(path), target_size=size)
        ).unsqueeze(0)

    def predict(
        self,
        content: Path = Input(description="Content image."),
        style: Path = Input(description="Style image."),
        colorize: bool = Input(
            default=True,
            description="Whether to use color correction in the output.",
        ),
        high_res: bool = Input(
            default=False,
            description="Whether to output a high resolution image (1024 instead of 512).",
        ),
        alpha: float = Input(
            default=0.75,
            ge=0.0,
            le=1.0,
            description="alpha=1.0 corresponds to maximum content preservation, alpha=0.0 is maximum stylization.",
        ),
        content_loss: bool = Input(
            default=False, description="Whether to use experimental content loss."
        ),
    ) -> Path:
        """Stylize ``content`` with ``style`` and return the output PNG path.

        Args:
            content: Path to the content image.
            style: Path to the style image.
            colorize: Apply color correction to the output (passed to
                produce_stylization as ``dont_colorize=not colorize``).
            high_res: Work at 1024 px with 5 scales instead of 512 px with 4.
            alpha: Content/style trade-off in [0, 1]; the content weight used
                by the optimization is ``1 - alpha``.
            content_loss: Enable the experimental content loss.

        Returns:
            Path to ``output.png`` inside a freshly created temporary directory.

        Raises:
            ValueError: If ``alpha`` lies outside [0, 1] (or is NaN).
            RuntimeError: If GPU use is requested but CUDA is unavailable.
        """
        # Working resolution and number of pyramid scales.
        if high_res:
            max_scls, sz = 5, 1024
        else:
            max_scls, sz = 4, 512
        flip_aug = True
        misc.USE_GPU = True
        content_weight = 1.0 - alpha

        # Cog enforces ge/le on alpha, but predict() can also be called
        # directly; explicit raises (unlike assert) survive `python -O`.
        # Errors for bad paths are left to imageio.
        if not 0.0 <= content_weight <= 1.0:
            raise ValueError("alpha must be between 0 and 1")
        if misc.USE_GPU and not torch.cuda.is_available():
            raise RuntimeError("attempted to use gpu when unavailable")

        content_im_orig = self._load_image(content, sz)
        style_im_orig = self._load_image(style, sz)

        # Synchronize around the run so the wall-clock time covers queued
        # CUDA work, not just kernel launches.
        torch.cuda.synchronize()
        start_time = time.time()
        output = produce_stylization(
            content_im_orig,
            style_im_orig,
            self.phi,
            max_iter=200,
            lr=2e-3,
            content_weight=content_weight,
            max_scls=max_scls,
            flip_aug=flip_aug,
            content_loss=content_loss,
            dont_colorize=not colorize,
        )
        torch.cuda.synchronize()
        print("Done! total time: {}".format(time.time() - start_time))

        # First batch element, CHW -> HWC, to numpy, clipped to [0, 1].
        new_im_out = np.clip(
            output[0].permute(1, 2, 0).detach().cpu().numpy(), 0.0, 1.0
        )
        save_im = (new_im_out * 255).astype(np.uint8)

        # Only the returned file is written; the stray debug write of
        # "ooo.png" into the working directory has been removed.
        out_path = Path(tempfile.mkdtemp()) / "output.png"
        imwrite(str(out_path), save_im)

        # Release cached GPU memory in case something else needs it later.
        if misc.USE_GPU:
            torch.cuda.empty_cache()
        return out_path