Fix torch._C.Node attribute access (openai#372)
Attribute access with subscripting previously worked due to patching
in pytorch/pytorch#82511, but that patch has since been removed.

This commit uses the fix proposed in pytorch/pytorch#82628
to define a helper method that calls the appropriate typed access method.
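
For reference, the helper from pytorch/pytorch#82628 looks roughly like the sketch below; the exact name and call sites are assumptions here, since the model code itself is not shown in this diff:

import torch

def _node_get(node: torch._C.Node, key: str):
    # torch._C.Node stores attributes behind type-specific accessors
    # (node.i, node.f, node.s, ...). kindOf(key) returns the name of the
    # accessor matching the attribute's type, so we dispatch via getattr
    # instead of subscripting the node directly.
    sel = node.kindOf(key)
    return getattr(node, sel)(key)

# usage: value = _node_get(node, "value")  # replaces the old node["value"]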
grandcyw committed Nov 10, 2023
1 parent 8478670 commit f6e4bbf
Showing 8 changed files with 383 additions and 3 deletions.
Binary file added 4.png
128 changes: 128 additions & 0 deletions split2_dataset_test.py
@@ -0,0 +1,128 @@
import matplotlib
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import clip
from collections import OrderedDict
import torch
from torchvision.datasets import CIFAR100



model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()  # the inputs below are moved to CUDA, so the model must be too

# CLIP's preprocessing pipeline, written out for reference
# (clip.load already returns an equivalent `preprocess`):
Compose([
    Resize(size=224, max_size=None, antialias=None),
    CenterCrop(size=224),
    ToTensor(),
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))])

print(clip.tokenize("thyroid cytology. Ciliated respiratory epithelial cells. These may be obtained from inadvertent sampling of the trachea during a thyroid FNA. (ThinPrep, Papanicolaou.) "))


# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

for filename in [filename for filename in os.listdir(skimage.data_dir)
                 if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()

image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)


cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()


with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(text_probs)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()
1 change: 0 additions & 1 deletion test11.py
@@ -14,7 +14,6 @@


model, preprocess = clip.load("ViT-B/32")
-model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
4 changes: 2 additions & 2 deletions test_prob.py → testa.py
@@ -1,6 +1,6 @@
# -- coding: utf-8 --
-# @Time : 2023/11/8 9:37
+# @Time : 2023/11/9 12:13
# @Author : 王川远
# @Email : 3030764269@qq.com
-# @File : test_prob.py
+# @File : testa.py
# @Software: PyCharm
130 changes: 130 additions & 0 deletions translate.py
@@ -0,0 +1,130 @@
# -- coding: utf-8 --
# @Time : 2023/11/9 11:55
# @Author : 王川远
# @Email : 3030764269@qq.com
# @File : translate.py
# @Software: PyCharm
import random
import requests
from hashlib import md5
import os
import wget
import paddle
import ftfy
import html
from IPython.display import display, HTML, Image  # notebook display helpers used by display_photo
import clip  # assumed: the PaddlePaddle CLIP port providing load() and tokenize()
from clip import tokenize

# Client for the Baidu Translate API
class Translator:
    def __init__(self):
        self.appid = '20210313000725566'
        self.appkey = 'anWY5DNo2Ab57bgmXnqR'
        self.url = 'http://api.fanyi.baidu.com/api/trans/vip/translate'
        self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        self.payload = {
            'appid': self.appid,
            'from': 'zh',
            'to': 'en',
        }

    @staticmethod
    def make_md5(s, encoding='utf-8'):
        return md5(s.encode(encoding)).hexdigest()

    def translate(self, query):
        salt = random.randint(32768, 65536)
        sign = self.make_md5(self.appid + query + str(salt) + self.appkey)

        self.payload['salt'] = salt
        self.payload['sign'] = sign
        self.payload['q'] = query
        r = requests.post(self.url, params=self.payload, headers=self.headers)
        result = r.json()['trans_result'][0]['dst']

        return result





# Retrieval engine
class IMSP:
    def __init__(self, db_file=None):
        self.model, self.transforms = clip.load('ViT_B_32', pretrained=True)
        if db_file is None:
            db_file = 'image_db'
            db_url = 'https://bj.bcebos.com/v1/ai-studio-online/775e9601019646b2a09f717789a4602f069a26302f8643418ec7c2370b895da9?responseContentDisposition=attachment%3B%20filename%3Dimage_db'
            if not os.path.isfile(db_file):
                wget.download(db_url)
        self.image_features, self.photo_ids = self.load_db(db_file)
        self.translator = Translator()

    @staticmethod
    def load_db(db_file):
        image_db = paddle.load(db_file)

        image_features = image_db['image_features'].astype('float32')
        image_features = paddle.to_tensor(image_features)

        photo_ids = image_db['photo_ids']

        return image_features, photo_ids

    @staticmethod
    def get_urls(photo_ids):
        urls = []
        for photo_id in photo_ids:
            url = f"https://unsplash.com/photos/{photo_id}"
            urls.append(url)
        return urls

    @staticmethod
    def is_chinese(texts):
        return any('\u4e00' <= char <= '\u9fff' for char in texts)

    # Search images; topk is the number of results to return
    def im_search(self, texts, topk=5, return_urls=True):
        if self.is_chinese(texts):
            texts = self.translator.translate(texts)

        texts = tokenize(texts)
        with paddle.no_grad():
            text_features = self.model.encode_text(texts)

        logit_scale = self.model.logit_scale.exp()
        logits_per_text = logit_scale * text_features @ self.image_features.t()

        indices = logits_per_text.topk(topk)[1][0]
        photo_ids = [self.photo_ids[index] for index in indices]

        if return_urls:
            return self.get_urls(photo_ids)
        else:
            return photo_ids




def display_photo(photo_urls):
    for photo_url in photo_urls:
        photo_preview_url = photo_url + "/download?w=224"
        display(Image(url=photo_preview_url))
        display(HTML(f'Click for the original image: <a target="_blank" href="{photo_url}">Unsplash Link</a>'))





# Instantiate the retrieval engine
imsp_engine = IMSP()

photo_urls = imsp_engine.im_search('公交车', topk=5)  # Chinese query: "bus"
print(photo_urls)
display_photo(photo_urls)

photo_urls = imsp_engine.im_search('blue sky with cloud', topk=5)
print(photo_urls)

# Display the results
display_photo(photo_urls)
32 changes: 32 additions & 0 deletions wcy.py
@@ -0,0 +1,32 @@
# -- coding: utf-8 --
# @Time : 2023/11/8 9:37
# @Author : 王川远
# @Email : 3030764269@qq.com
# @File : wcy.py
# @Software: PyCharm
import torch
import clip
from PIL import Image
import time


t0 = time.time()
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
t1 = time.time()
print(t1 - t0)  # model load time
image = preprocess(Image.open("4.png")).unsqueeze(0).to(device)
t2 = time.time()
print(t2 - t1)  # image preprocessing time
text = clip.tokenize(["a painting", "medical cells", "thyroid cytology","thyroid cytology. Ciliated respiratory epithelial cells. These may be obtained from inadvertent sampling of the trachea during a thyroid FNA. (ThinPrep, Papanicolaou.) ","thyroid cytology.Acute thyroiditis. Marked acute inflammation and debris are seen, but follicular cells and colloid are absent. (Smear, Papanicolaou.) "]).to(device)
t3 = time.time()
print(t3 - t2)  # tokenization time
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    print(logits_per_image)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
39 changes: 39 additions & 0 deletions wcy2.py
@@ -0,0 +1,39 @@
# -- coding: utf-8 --
# @Time : 2023/11/9 10:25
# @Author : 王川远
# @Email : 3030764269@qq.com
# @File : wcy2.py
# @Software: PyCharm
import os
import clip
import torch
from torchvision.datasets import CIFAR100
from PIL import Image

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image = Image.open("4.png")
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
