gen_corpus_embeddings.py
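"""Generate CLIP text embeddings for a training corpus of report impressions.

Loads a CLIP model (either the pretrained ViT-B/32 architecture or the
non-pretrained variant defined by `nonpretrained_params`), encodes every
report in the input csv in batches, and saves the (impressions, embeddings)
pair under corpus_embeddings/ with torch.save.

Illustrative invocation (file names are placeholders):
    python gen_corpus_embeddings.py --clip_model_path checkpoints/clip_model.pt \
        --data_path data/corpus.csv --out corpus_embeddings.pt
"""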
import argparse
import os

import pandas as pd
import torch
from tqdm import tqdm

import clip
from utils import nonpretrained_params

def encode_texts(imps, model, device):
    """Tokenize a batch of impressions and return L2-normalized CLIP text embeddings."""
    with torch.no_grad():
        imp_toks = clip.tokenize(imps, context_length=model.context_length).to(device)
        embeddings = model.encode_text(imp_toks)
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    return embeddings

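# Because encode_texts L2-normalizes its output, cosine similarity between a
# query embedding and the corpus reduces to a matrix product. A minimal sketch,
# assuming `query` and `corpus` are tensors returned by encode_texts:
#
#     scores = query @ corpus.T      # cosine similarities, shape (n_query, n_corpus)
#     best = scores.argmax(dim=-1)   # index of the closest corpus impression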
def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the model architecture, then load the fine-tuned weights on top.
    if args.clip_pretrained:  # CLIP architecture pretrained on natural images
        model, _ = clip.load("ViT-B/32", device=device, jit=False)
        print("Loaded in pretrained model.")
    else:
        model = clip.CLIP(**nonpretrained_params)
        print("Loaded in clip model.")
    model.load_state_dict(torch.load(args.clip_model_path, map_location=device))
    model = model.to(device)

    impressions = pd.read_csv(args.data_path)["report"]
    impressions_size = impressions.shape[0]

    # Encode the corpus in full batches.
    bs = args.batch_size
    num_batches = impressions_size // bs
    tensors = []
    for i in tqdm(range(num_batches)):
        batch = impressions[bs * i : bs * (i + 1)]
        tensors.append(encode_texts(batch, model, device))
    # Encode any remaining impressions that did not fill a full batch.
    if impressions_size % bs != 0:
        tensors.append(encode_texts(impressions[bs * num_batches:], model, device))
    clip_embeddings = torch.cat(tensors)
    print(impressions.shape, clip_embeddings.shape)

    # Save the (impressions, embeddings) pair for later retrieval.
    out_data = (impressions, clip_embeddings)
    os.makedirs('corpus_embeddings', exist_ok=True)
    out_path = os.path.join('corpus_embeddings', args.out)
    torch.save(out_data, out_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate CLIP embeddings for a training corpus (either sentence level or report level)')
    parser.add_argument('--clip_model_path', type=str, required=True, help='path to the CLIP model state dictionary used to generate embeddings')
    parser.add_argument('--clip_pretrained', action='store_true', help='whether the CLIP model was first pretrained on natural images')
    parser.add_argument('--data_path', type=str, required=True, help='path to a csv file containing the training corpus (either sentence level or report level)')
    parser.add_argument('--out', type=str, required=True, help='filename for the saved corpus embeddings (include the .pt extension)')
    parser.add_argument('--batch_size', type=int, default=2000, help='batch size for generating CLIP embeddings')
    args = parser.parse_args()
    main(args)
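
# To reuse the saved corpus later (illustrative path, matching the torch.save
# call in main above):
#
#     impressions, clip_embeddings = torch.load('corpus_embeddings/corpus_embeddings.pt')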