forked from openai/CLIP
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix torch._C.Node attribute access (openai#372)
Attribute access with subscripting would previously work due to patching in pytorch/pytorch#82511 but this has been removed. This commit uses the fix proposed in pytorch/pytorch#82628 to define a helper method to call the appropriate access method.
- Loading branch information
Showing
8 changed files
with
383 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import matplotlib | ||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | ||
import os | ||
import skimage | ||
import IPython.display | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import numpy as np | ||
import clip | ||
from collections import OrderedDict | ||
import torch | ||
from torchvision.datasets import CIFAR100 | ||
|
||
|
||
|
||
# Pick the device once so the model and every input tensor agree on it;
# unconditional .cuda() calls fail on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Debug aid: show the token ids produced for one long caption.
print(clip.tokenize("thyroid cytology. Ciliated respiratory epithelial cells. These may be obtained from inadvertent sampling of the trachea during a thyroid FNA. (ThinPrep, Papanicolaou.) "))

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer",
}

original_images = []  # PIL images, kept for plotting
images = []           # preprocessed tensors fed to the model
texts = []            # caption for each image, in the same order
plt.figure(figsize=(16, 5))

for filename in os.listdir(skimage.data_dir):
    if not filename.endswith((".png", ".jpg")):
        continue
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()

# Fail early with a clear message instead of an opaque stacking error.
if not images:
    raise RuntimeError(f"no described sample images found in {skimage.data_dir}")

image_input = torch.stack(images).to(device)
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

# Normalise to unit length so the dot product below is a cosine similarity.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

# Count the images that were actually loaded: some entries of `descriptions`
# may have no matching file, and the tick labels must match `texts`.
count = len(texts)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
# Draw each source image as a thumbnail above its column of the matrix.
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)

# Zero-shot classification of the same images against the CIFAR-100 labels.
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(text_probs)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

# One row pair per image: the picture on the left, its top-5 labels on the right.
for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# -- coding: utf-8 -- | ||
# @Time : 2023/11/8 9:37 | ||
# @Time : 2023/11/9 12:13 | ||
# @Author : 王川远 | ||
# @Email : 3030764269@qq.com | ||
# @File : test_prob.py | ||
# @File : testa.py | ||
# @Software: PyCharm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# -- coding: utf-8 -- | ||
# @Time : 2023/11/9 11:55 | ||
# @Author : 王川远 | ||
# @Email : 3030764269@qq.com | ||
# @File : translate.py | ||
# @Software: PyCharm | ||
import random | ||
import requests | ||
from hashlib import md5 | ||
import os | ||
import wget | ||
import paddle | ||
from PIL import Image | ||
import HTML | ||
import ftfy | ||
import html | ||
|
||
# Client for the Baidu Translate service
class Translator:
    """Small client for the Baidu Translate HTTP API (Chinese to English).

    Args:
        appid: Baidu application id sent with every request.
        appkey: Secret key used only to sign requests.
    """

    def __init__(self, appid='20210313000725566', appkey='anWY5DNo2Ab57bgmXnqR'):
        # NOTE(review): live credentials are kept as defaults only for
        # backward compatibility; they should be rotated and supplied by the
        # caller (or configuration) instead of living in source control.
        self.appid = appid
        self.appkey = appkey
        self.url = 'http://api.fanyi.baidu.com/api/trans/vip/translate'
        self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        # Request fields shared by every call; never mutated after this point.
        self.payload = {
            'appid': self.appid,
            'from': 'zh',
            'to': 'en',
        }

    @staticmethod
    def make_md5(s, encoding='utf-8'):
        """Return the hexadecimal MD5 digest of ``s`` encoded with ``encoding``."""
        return md5(s.encode(encoding)).hexdigest()

    def _build_payload(self, query, salt):
        """Return a fresh, signed request payload for ``query``.

        The signature is the MD5 of appid + query + salt + appkey. A new dict
        is built on every call so the shared ``self.payload`` is not modified.
        """
        sign = self.make_md5(self.appid + query + str(salt) + self.appkey)
        return {**self.payload, 'salt': salt, 'sign': sign, 'q': query}

    def translate(self, query):
        """Translate ``query`` and return the first translated segment.

        Raises:
            requests.HTTPError: if the service answers with an HTTP error.
            RuntimeError: if the response carries no ``trans_result``.
        """
        salt = random.randint(32768, 65536)
        payload = self._build_payload(query, salt)

        # A timeout keeps a stalled connection from blocking the caller forever.
        r = requests.post(self.url, params=payload, headers=self.headers, timeout=10)
        r.raise_for_status()

        data = r.json()
        if 'trans_result' not in data:
            raise RuntimeError(
                "Baidu translate request failed: "
                f"error_code={data.get('error_code')!r}, "
                f"error_msg={data.get('error_msg')!r}"
            )
        return data['trans_result'][0]['dst']
|
||
|
||
|
||
|
||
|
||
# Image retrieval engine
class IMSP:
    """Text-to-image search over a database of precomputed image features.

    Args:
        db_file: Path of the feature database. When omitted, the file
            ``image_db`` is used and downloaded first if it does not exist.
    """

    def __init__(self, db_file=None):
        # This module does not import ``clip`` at file level, so import it
        # here to keep the class usable on its own.
        # NOTE(review): confirm that the installed ``clip`` package accepts
        # ``load('ViT_B_32', pretrained=True)``; the call is kept as written.
        import clip

        self.model, self.transforms = clip.load('ViT_B_32', pretrained=True)
        if db_file is None:
            db_file = 'image_db'
            db_url = 'https://bj.bcebos.com/v1/ai-studio-online/775e9601019646b2a09f717789a4602f069a26302f8643418ec7c2370b895da9?responseContentDisposition=attachment%3B%20filename%3Dimage_db'
            if not os.path.isfile(db_file):
                # Save under the exact name that load_db() opens next.
                wget.download(db_url, out=db_file)
        self.image_features, self.photo_ids = self.load_db(db_file)
        self.translator = Translator()

    @staticmethod
    def load_db(db_file):
        """Load the database and return ``(image_features, photo_ids)``.

        The features are cast to float32 and wrapped in a paddle tensor; the
        photo ids are returned as stored.
        """
        image_db = paddle.load(db_file)

        image_features = image_db['image_features'].astype('float32')
        image_features = paddle.to_tensor(image_features)

        photo_ids = image_db['photo_ids']

        return image_features, photo_ids

    @staticmethod
    def get_urls(photo_ids):
        """Return the Unsplash photo page URL for each id, in order."""
        return [f"https://unsplash.com/photos/{photo_id}" for photo_id in photo_ids]

    @staticmethod
    def is_chinese(texts):
        """Return True if any character lies in the range U+4E00 to U+9FFF."""
        return any('\u4e00' <= char <= '\u9fff' for char in texts)

    # Search images; topk is the number of results to return.
    def im_search(self, texts, topk=5, return_urls=True):
        """Return the ``topk`` best matches for the query ``texts``.

        Queries containing Chinese characters are translated to English
        first. Returns photo page URLs when ``return_urls`` is true,
        otherwise the raw photo ids.
        """
        import clip  # see the note in __init__

        if self.is_chinese(texts):
            texts = self.translator.translate(texts)

        texts = clip.tokenize(texts)
        with paddle.no_grad():
            text_features = self.model.encode_text(texts)

            logit_scale = self.model.logit_scale.exp()
            logits_per_text = logit_scale * text_features @ self.image_features.t()

            # topk() yields (values, indices); take the indices of the first
            # (only) query row.
            indexs = logits_per_text.topk(topk)[1][0]
        photo_ids = [self.photo_ids[index] for index in indexs]

        if return_urls:
            return self.get_urls(photo_ids)
        return photo_ids
|
||
|
||
|
||
|
||
def display_photo(photo_urls):
    """Render a 224-pixel-wide preview and a source link for each photo URL.

    Args:
        photo_urls: Iterable of photo page URLs, as returned by
            ``IMSP.im_search``.
    """
    # The module-level names ``Image`` (PIL's image module) and ``HTML`` (a
    # module) are not callable, and ``display`` is not defined at all; the
    # rich-output helpers needed here live in IPython.display.
    from IPython.display import HTML, Image, display

    for photo_url in photo_urls:
        photo_preview_url = photo_url + "/download?w=224"
        display(Image(url=photo_preview_url))
        # Escape the URL so it cannot break out of the href attribute.
        safe_url = html.escape(photo_url, quote=True)
        display(HTML(f'原图请点击:<a target="_blank" href="{safe_url}">Unsplash Link</a>'))
|
||
|
||
|
||
|
||
|
||
# Instantiate the retrieval engine.
imsp_engine = IMSP()

# Run one Chinese and one English query; for each, print the matching URLs
# and then show the results.
for query in ('公交车', 'blue sky with cloud'):
    photo_urls = imsp_engine.im_search(query, topk=5)
    print(photo_urls)
    display_photo(photo_urls)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# -- coding: utf-8 -- | ||
# @Time : 2023/11/8 9:37 | ||
# @Author : 王川远 | ||
# @Email : 3030764269@qq.com | ||
# @File : wcy.py | ||
# @Software: PyCharm | ||
import torch | ||
import clip | ||
from PIL import Image | ||
import time | ||
|
||
|
||
# Stage 1: choose the device and load the model; print the elapsed seconds.
start_time = time.time()
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model_ready_time = time.time()
print(model_ready_time - start_time)

# Stage 2: read and preprocess the image, adding a leading batch dimension.
image = preprocess(Image.open("4.png")).unsqueeze(0).to(device)
image_ready_time = time.time()
print(image_ready_time - model_ready_time)

# Stage 3: tokenize the candidate captions.
prompts = [
    "a painting",
    "medical cells",
    "thyroid cytology",
    "thyroid cytology. Ciliated respiratory epithelial cells. These may be obtained from inadvertent sampling of the trachea during a thyroid FNA. (ThinPrep, Papanicolaou.) ",
    "thyroid cytology.Acute thyroiditis. Marked acute inflammation and debris are seen, but follicular cells and colloid are absent. (Smear, Papanicolaou.) ",
]
text = clip.tokenize(prompts).to(device)
text_ready_time = time.time()
print(text_ready_time - image_ready_time)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    print(logits_per_image)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # softmax over the prompts listed above
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# -- coding: utf-8 -- | ||
# @Time : 2023/11/9 10:25 | ||
# @Author : 王川远 | ||
# @Email : 3030764269@qq.com | ||
# @File : wcy2.py | ||
# @Software: PyCharm | ||
import os | ||
import clip | ||
import torch | ||
from torchvision.datasets import CIFAR100 | ||
from PIL import Image | ||
|
||
# Load CLIP on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Fetch the CIFAR-100 test split; only its class names are used below.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Build the inputs: one local image and one prompt per class name.
image = Image.open("4.png")
image_input = preprocess(image).unsqueeze(0).to(device)
class_tokens = [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
text_inputs = torch.cat(class_tokens).to(device)

# Encode both modalities without tracking gradients.
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Normalise to unit length, then softmax the scaled similarities and keep
# the five most likely labels for the image.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Report each label with its probability as a percentage.
print("\nTop predictions:\n")
for probability, class_index in zip(values.tolist(), indices.tolist()):
    print(f"{cifar100.classes[class_index]:>16s}: {100 * probability:.2f}%")
Oops, something went wrong.