Skip to content

Commit

Permalink
Fix torch._C.Node attribute access (openai#372)
Browse files Browse the repository at this point in the history
Attribute access with subscripting would previously work
due to patching in pytorch/pytorch#82511
but this has been removed.

This commit uses the fix proposed in pytorch/pytorch#82628
to define a helper method to call the appropriate access method.
  • Loading branch information
jamt9000 authored and grandcyw committed Nov 8, 2023
1 parent a9b1bf5 commit 8478670
Show file tree
Hide file tree
Showing 9 changed files with 1,237 additions and 882 deletions.
14 changes: 11 additions & 3 deletions clip/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

def _node_get(node: torch._C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type.
From https://github.com/pytorch/pytorch/pull/82628
"""
sel = node.kindOf(key)
return getattr(node, sel)(key)

def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
Expand All @@ -156,7 +164,7 @@ def patch_device(module):

for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
node.copyAttributes(device_node)

model.apply(patch_device)
Expand All @@ -182,7 +190,7 @@ def patch_float(module):
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
if _node_get(inputs[i].node(), "value") == 5:
inputs[i].node().copyAttributes(float_node)

model.apply(patch_float)
Expand All @@ -194,7 +202,7 @@ def patch_float(module):
return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Expand Down
19 changes: 19 additions & 0 deletions clip/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
import torch
import clip
from pkg_resources import packaging

print("Torch version:", torch.__version__)


clip.available_models()
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
148 changes: 148 additions & 0 deletions image_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import matplotlib
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import clip
from collections import OrderedDict
import torch
from torchvision.datasets import CIFAR100

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

Compose([
Resize(size=224, max_size=None, antialias=None),
CenterCrop(size=224),
ToTensor(),
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))])

# images in skimage to use and their textual descriptions

descriptions = {}
original_images = []
images = []
texts = []
data_dir = 'C:/keywords'
txt_path = 'C:/keywords/newDescriptions.txt'
labels_path = 'C:/filter/newLabels.txt'
max_token_size = 77
imshow_num = 4

with open(txt_path, 'r', encoding='utf-8') as f:
content = f.read()
f.close()
i = 0
while i < len(content):
if content[i:i + 5] == "<key>":
i += 6
fig = ""
while content[i] != '<':
fig += content[i]
i += 1
i += 15
des = ""
while content[i] != '<':
des += content[i]
i += 1
fig = fig[:-1]
des = des[:-1]
descriptions[fig] = des
else:
i += 1
print(f'{descriptions}')
for filename in [filename for filename in os.listdir(data_dir) if
filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue

image = Image.open(os.path.join(data_dir, filename)).convert("RGB")
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])

image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize([desc for desc in texts]).cuda()

with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = imshow_num

# print(f'{len(original_images)}')
plt.figure(figsize=(10, 7))
plt.imshow(similarity, vmin=0, vmax=1)
# plt.colorbar()
plt.yticks(range(count), texts[0:imshow_num], fontsize=9)
plt.xticks([])
for i, image in enumerate(original_images):
if i > imshow_num-1:
break
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
if x>imshow_num-1:
continue
for y in range(similarity.shape[0]):
if y>imshow_num-1:
continue
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=6)

for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=10)

labels = []
with open(labels_path, 'r', encoding='utf-8') as f1:
line = f1.readline()
while line:
if len(line) < max_token_size:
labels.append(line)
line = f1.readline()
f1.close()
# print(f'{labels}')
text_descriptions = [f'{label}' for label in labels]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
if i > 3:
break
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")

plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [labels[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()
6 changes: 6 additions & 0 deletions logs/Untitled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 8478670

Please sign in to comment.