# build_vocab_fashionpedia.py
import os
from collections import Counter
from glob import glob

import nltk
from tqdm import tqdm

from lib.utils.directory import read_json, write_json
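# NOTE: nltk.tokenize.word_tokenize requires the "punkt" tokenizer models;
# if they are not installed, download them once with nltk.download("punkt").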
def clean_caption(caption):
    # Normalize punctuation before tokenization.
    return (
        caption.replace("-", " ")
        .replace(".", "")
        .replace("(", " ")
        .replace(")", " ")
        .replace(",", " ")
    )


def get_tokens(caption):
    # Map a caption to a list of word ids using the module-level `vocab`
    # built below. Words missing from the vocab fall back to the <UNK> id.
    toks = nltk.tokenize.word_tokenize(clean_caption(caption).lower())
    return [vocab.get(word, vocab["<UNK>"]) for word in toks]
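# Illustrative example: get_tokens("red dress, short-sleeve")
# -> [vocab["red"], vocab["dress"], vocab["short"], vocab["sleeve"]]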
# Build vocab
cap_files = glob("train_data/fashionpedia/*_triplets_*.json")
counter = Counter()
for cap_file in cap_files:
    cap_data = read_json(cap_file)
    for data in cap_data:
        for caption in data["captions"]:
            toks = nltk.tokenize.word_tokenize(clean_caption(caption).lower())
            counter.update(toks)
print("Total Words:", len(counter))
vocab = dict(zip(counter.keys(), range(5, len(counter) + 6))) # remain for other tokens
vocab["<NULL>"] = 0
vocab["<UNK>"] = 1
vocab["<START>"] = 2
vocab["<END>"] = 3
vocab["<LINK>"] = 4
# Save vocab
write_json(vocab, "train_data/fashionpedia/vocab.json")
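# vocab.json maps each token to its integer id, e.g.
# {"<NULL>": 0, "<UNK>": 1, ..., "dress": 5, ...} ("dress" is illustrative).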
# Build cap file
for cap_file in cap_files:
    # Load data
    cap_data = read_json(cap_file)
    # Build the save name by inserting "dict" into the file name, e.g.
    # hybrid_triplets_test_turn3.json -> hybrid_triplets_dict_test_turn3.json
    sn = os.path.basename(cap_file).split("_")
    sn = "_".join(sn[:2] + ["dict"] + sn[2:])
    save_file = os.path.join(os.path.dirname(cap_file), sn)
    # Process: encode each triplet's two captions, joined by <LINK>
    for data in tqdm(cap_data):
        captions = data["captions"]
        data["wv"] = (
            get_tokens(captions[0]) + [vocab["<LINK>"]] + get_tokens(captions[1])
        )
    write_json(cap_data, save_file)
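# Each processed entry now carries a "wv" field (illustrative ids):
#   {"captions": [c0, c1], "wv": [12, 7, ..., 4, ..., 9], ...}
# where 4 is the <LINK> id separating the two encoded captions.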
# Alternative path (commented out): encode multi-turn caption data, where
# each entry's "captions" is a list of caption pairs and "wv" becomes a
# list of token sequences, one per turn.
# # cap_files = glob("train_data/fashionpedia/comp_miner*.json")
# cap_files = glob("train_data/fashionpedia/hybrid_triplets_test_turn3.json")
# vocab = read_json("train_data/fashionpedia/vocab.json")
# for cap_file in cap_files:
#     print(cap_file)
#     # Load data
#     cap_data = read_json(cap_file)
#     # Save name
#     sn = os.path.basename(cap_file).split("_")
#     sn = "_".join(sn[:2] + ["dict"] + sn[2:])
#     save_file = os.path.join(os.path.dirname(cap_file), sn)
#     # Process
#     for data in tqdm(cap_data):
#         captions = data["captions"]
#         data["wv"] = []
#         for caption in captions:
#             data["wv"].append(
#                 get_tokens(caption[0]) + [vocab["<LINK>"]] + get_tokens(caption[1])
#             )
#     write_json(cap_data, save_file)