# example.py
# Forked from facebookresearch/ImageBind.
import logging
import torch
import data
from models import imagebind_model
from models.imagebind_model import ModalityType, load_module
from models import lora as LoRA
# Emit INFO-level logs; force=True replaces any handlers already attached to the root logger.
logging.basicConfig(level=logging.INFO, force=True)

# --- Run configuration -------------------------------------------------------
lora = True  # apply LoRA to the trunks and load the LoRA checkpoint
linear_probing = False  # load only the linear-probing heads instead
device = "cpu"  # "cuda:0" if torch.cuda.is_available() else "cpu"
load_head_post_proc_finetuned = True  # also load the fine-tuned postprocessors & heads

# The two modes are mutually exclusive. Raise explicitly instead of using
# `assert`, so the check is not stripped when running under `python -O`.
if linear_probing and lora:
    raise ValueError(
        "Linear probing is a subset of LoRA training procedure for ImageBind. "
        "Cannot set both linear_probing=True and lora=True. "
    )

# Scale applied to the text-similarity logits before the softmax further down.
if lora and not load_head_post_proc_finetuned:
    # Hack: adjust lora_factor to the `max batch size used during training / temperature` to compensate missing norm
    lora_factor = 12 / 0.07
else:
    # This assumes proper loading of all params but results in shift from original dist in case of LoRA
    lora_factor = 1
# Text prompts; the names match the image files below, in the same order.
text_list = [
    "bird",
    "car",
    "dog3",
    "dog5",
    "dog8",
    "grey_sloth_plushie",
]

# One image per text prompt.
image_paths = [
    ".assets/bird_image.jpg",
    ".assets/car_image.jpg",
    ".assets/dog3.jpg",
    ".assets/dog5.jpg",
    ".assets/dog8.jpg",
    ".assets/grey_sloth_plushie.jpg",
]

# Audio clips (fewer than the prompts: bird, car and a single dog clip).
audio_paths = [
    ".assets/bird_audio.wav",
    ".assets/car_audio.wav",
    ".assets/dog_audio.wav",
]
# Checkpoint location shared by the LoRA weights and the fine-tuned
# postprocessors/heads, plus the filename postfix used by every checkpoint.
lora_checkpoint_dir = ".checkpoints/lora/550_epochs_lora"
checkpoint_postfix = "_dreambooth_last"

# Instantiate model
model = imagebind_model.imagebind_huge(pretrained=True)

# NOTE(review): indentation was reconstructed; `elif linear_probing` is paired
# with `if lora` because the two modes are mutually exclusive, and the
# head/postprocessor loading is treated as a LoRA sub-option — confirm.
if lora:
    # Apply rank-4 LoRA to the TEXT and VISION trunks and store the result
    # back on the model.
    model.modality_trunks.update(
        LoRA.apply_lora_modality_trunks(
            model.modality_trunks,
            rank=4,
            # layer_idxs={ModalityType.TEXT: [0, 1, 2, 3, 4, 5, 6, 7, 8],
            #             ModalityType.VISION: [0, 1, 2, 3, 4, 5, 6, 7, 8]},
            modality_names=[ModalityType.TEXT, ModalityType.VISION],
        )
    )

    # Load LoRA params if found
    LoRA.load_lora_modality_trunks(
        model.modality_trunks,
        checkpoint_dir=lora_checkpoint_dir,
        postfix=checkpoint_postfix,
    )

    if load_head_post_proc_finetuned:
        # Load postprocessors & heads
        load_module(
            model.modality_postprocessors,
            module_name="postprocessors",
            checkpoint_dir=lora_checkpoint_dir,
            postfix=checkpoint_postfix,
        )
        load_module(
            model.modality_heads,
            module_name="heads",
            checkpoint_dir=lora_checkpoint_dir,
            postfix=checkpoint_postfix,
        )
elif linear_probing:
    # Load heads
    load_module(
        model.modality_heads,
        module_name="heads",
        checkpoint_dir="./.checkpoints/lora/500_epochs_lp",
        postfix=checkpoint_postfix,
    )

# Switch to evaluation mode and move the model to the configured device.
model.eval()
model.to(device)
# Load data
# Build the model input one modality at a time, keyed by ModalityType; each
# loader is handed the raw list plus the target device.
inputs = {}
inputs[ModalityType.TEXT] = data.load_and_transform_text(text_list, device)
inputs[ModalityType.VISION] = data.load_and_transform_vision_data(
    image_paths, device, to_tensor=True
)
inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data(audio_paths, device)
# Single forward pass; gradients are not needed for inference.
with torch.no_grad():
    embeddings = model(inputs)

vision_embeddings = embeddings[ModalityType.VISION]
text_embeddings = embeddings[ModalityType.TEXT]
audio_embeddings = embeddings[ModalityType.AUDIO]

# Scale for the similarities that involve text; only applied when LoRA is on.
text_logit_scale = lora_factor if lora else 1

# Each row is a softmax over the candidates of the second modality.
print(
    "Vision x Text: ",
    torch.softmax(vision_embeddings @ text_embeddings.T * text_logit_scale, dim=-1),
)
print(
    "Audio x Text: ",
    torch.softmax(audio_embeddings @ text_embeddings.T * text_logit_scale, dim=-1),
)
# Vision x Audio is left unscaled, exactly as in the original script.
print(
    "Vision x Audio: ",
    torch.softmax(vision_embeddings @ audio_embeddings.T, dim=-1),
)