generate.py
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# Copyright (c) 2021 lucidrains
# This file has been modified by Graphcore
import argparse
from pathlib import Path

from tqdm import tqdm

# torch
import torch
import poptorch  # IPU support; WrappedDALLE below is assumed to drive poptorch execution internally
from einops import repeat

# vision imports
from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils
from models import VQGanVAE, WrappedDALLE
from models.tokenizer import SimpleTokenizer, YttmTokenizer

# argument parsing
parser = argparse.ArgumentParser()
parser.add_argument("--dalle_path", type=str, required=True, help="path to your trained DALL-E checkpoint")
parser.add_argument(
    "--vqgan_model_path",
    type=str,
    default=None,
    help="path to your trained VQGAN weights. This should be a .ckpt file.",
)
parser.add_argument(
    "--vqgan_config_path",
    type=str,
    default=None,
    help="path to your trained VQGAN config. This should be a .yaml file.",
)
parser.add_argument("--text", type=str, required=True, help="your text prompt(s); separate multiple prompts with |")
parser.add_argument("--num_images", type=int, default=128, required=False, help="number of images to generate per prompt")
parser.add_argument("--batch_size", type=int, default=4, required=False, help="batch size")
parser.add_argument("--top_k", type=float, default=0.99, required=False, help="top-k filtering threshold (passed as filter_thres)")
parser.add_argument("--outputs_dir", type=str, default="./outputs", required=False, help="output directory")
parser.add_argument("--bpe_path", type=str, help="path to your trained YTTM BPE model file")
parser.add_argument("--gentxt", dest="gentxt", action="store_true", help="let the model complete the text prompt before generating images")
args = parser.parse_args()
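
# Example invocations (illustrative paths and prompts, not part of the original script):
#   python generate.py --dalle_path ./dalle.pt --text "a cat on a skateboard" --num_images 8
#   python generate.py --dalle_path ./dalle.pt --text "a red cube|a blue sphere"
# The second form generates a separate image set for each |-separated prompt.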

# helper fns
def exists(val):
    return val is not None

# tokenizer
if exists(args.bpe_path):
    tokenizer = YttmTokenizer(args.bpe_path)
else:
    tokenizer = SimpleTokenizer()
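# SimpleTokenizer is the default BPE tokenizer shipped with DALLE-pytorch;
# YttmTokenizer is assumed to wrap a custom YouTokenToMe BPE model trained for your dataset.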

# load DALL-E
dalle_path = Path(args.dalle_path)
assert dalle_path.exists(), "trained DALL-E must exist"

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop("hparams"), load_obj.pop("vae_params"), load_obj.pop("weights")
dalle_params.pop("vae", None)  # drop any VAE reference stored in the checkpoint; the VQGAN is rebuilt below
dalle_params["fp16"] = False

vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
dalle = WrappedDALLE(vae=vae, **dalle_params).eval()
dalle.load_state_dict(weights)
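# Checkpoint layout implied by the pops above (a sketch inferred from this code, not an official schema):
#   {"hparams":    DALL-E constructor kwargs,
#    "vae_params": VAE hyperparameters (unused here, since the VQGAN is rebuilt
#                  from --vqgan_model_path / --vqgan_config_path),
#    "weights":    the DALL-E state_dict}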

# generate images
image_size = vae.image_size
texts = args.text.split("|")

for j, text in tqdm(enumerate(texts)):
    if args.gentxt:
        # let DALL-E complete the prompt first, then generate images for the completion
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres=args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.model.text_seq_len, truncate_text=True)

    # replicate the single tokenized prompt across the batch dimension
    text_tokens = repeat(text_tokens, "() n -> b n", b=args.num_images)

    # generate in batch_size chunks to bound peak memory
    outputs = []
    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc=f"generating images for - {text}"):
        output = dalle.generate_images(text_chunk, filter_thres=args.top_k)
        outputs.append(output)
    outputs = torch.cat(outputs)

    # save all images
    file_name = text
    outputs_dir = Path(args.outputs_dir) / file_name.replace(" ", "_")[:100]  # filename length is limited to 100
    outputs_dir.mkdir(parents=True, exist_ok=True)
    for i, image in tqdm(enumerate(outputs), desc="saving images"):
        save_image(image, outputs_dir / f"{i}.jpg", normalize=True)

    with open(outputs_dir / "caption.txt", "w") as f:
        f.write(file_name)

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')
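
# Resulting directory layout (illustrative, with the default settings):
#   outputs/
#   └── a_cat_on_a_skateboard/
#       ├── 0.jpg ... 127.jpg   # num_images generated samples
#       └── caption.txt         # the (possibly model-completed) prompt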