
Commit

Merge pull request #1490 from vespa-engine/andreer/similarity-map-from-vespa

compute similarity map in vespa
thomasht86 authored Oct 11, 2024
2 parents e46b2f2 + 826e2ac commit b1dfeb8
Showing 3 changed files with 96 additions and 31 deletions.
114 changes: 84 additions & 30 deletions visual-retrieval-colpali/backend/colpali.py
@@ -9,9 +9,13 @@
from io import BytesIO
from typing import Union, Tuple, List, Dict, Any
import matplotlib
import matplotlib.cm as cm
import re
import io

import json
import time

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from einops import rearrange
@@ -114,6 +118,7 @@ def gen_similarity_maps(
query_embs: torch.Tensor,
token_idx_map: dict,
images: List[Union[Path, str]],
vespa_sim_maps: Union[List[Dict[str, Any]], None],
) -> List[Dict[str, str]]:
"""
Generate similarity maps for the given images and query, and return base64-encoded blended images.
@@ -131,8 +136,8 @@
Returns:
List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
"""
import numpy as np
import matplotlib.cm as cm

start = time.perf_counter()

# Prepare the colormap once to avoid recomputation
colormap = cm.get_cmap("viridis")
@@ -158,39 +163,74 @@
original_sizes.append(img_pil.size) # (width, height)
processed_images.append(img_pil)

# Preprocess inputs
input_image_processed = processor.process_images(processed_images).to(device)

# Forward passes
with torch.no_grad():
output_image = model.forward(**input_image_processed)

# Remove the special tokens from the output
output_image = output_image[:, : processor.image_seq_length, :]
# If similarity maps are provided, use them instead of computing them
if vespa_sim_maps:
print("Using provided similarity maps")
# A Vespa sim map (the "summaryfeatures" field) looks like this:
# "similarities": {
#     "cells": [
#         {
#             "address": {
#                 "patch": "0",
#                 "querytoken": "0"
#             },
#             "value": 1.2599412202835083
#         },
#         ... and so on.
#     ]
# }
# Turn these cells into a tensor with the same shape as the locally computed similarity map
vespa_sim_map_tensor = torch.zeros(
    (
        len(vespa_sim_maps),
        query_embs.size(dim=1),
        vit_config.n_patch_per_dim,
        vit_config.n_patch_per_dim,
    )
)
for idx, vespa_sim_map in enumerate(vespa_sim_maps):
for cell in vespa_sim_map["similarities"]["cells"]:
patch = int(cell["address"]["patch"])
if patch >= processor.image_seq_length:
continue
query_token = int(cell["address"]["querytoken"])
value = cell["value"]
vespa_sim_map_tensor[
    idx,
    query_token,
    patch // vit_config.n_patch_per_dim,
    patch % vit_config.n_patch_per_dim,
] = value

# Normalize the similarity map per query token
similarity_map_normalized = normalize_similarity_map_per_query_token(vespa_sim_map_tensor)
else:
# Preprocess inputs
print("Computing similarity maps")
start2 = time.perf_counter()
input_image_processed = processor.process_images(processed_images).to(device)

# Forward passes
with torch.no_grad():
output_image = model.forward(**input_image_processed)

# Remove the special tokens from the output
output_image = output_image[:, : processor.image_seq_length, :]

# Rearrange the output image tensor to represent the 2D grid of patches
output_image = rearrange(
output_image,
"b (h w) c -> b h w c",
h=vit_config.n_patch_per_dim,
w=vit_config.n_patch_per_dim,
)

# Rearrange the output image tensor to represent the 2D grid of patches
output_image = rearrange(
output_image,
"b (h w) c -> b h w c",
h=vit_config.n_patch_per_dim,
w=vit_config.n_patch_per_dim,
)
# Ensure query_embs has batch dimension
if query_embs.dim() == 2:
query_embs = query_embs.unsqueeze(0).to(device)
else:
query_embs = query_embs.to(device)

# Ensure query_embs has batch dimension
if query_embs.dim() == 2:
query_embs = query_embs.unsqueeze(0).to(device)
else:
query_embs = query_embs.to(device)
# Compute the similarity map
similarity_map = torch.einsum(
"bnk,bhwk->bnhw", query_embs, output_image
) # Shape: (batch_size, query_tokens, h, w)

# Compute the similarity map
similarity_map = torch.einsum(
"bnk,bhwk->bnhw", query_embs, output_image
) # Shape: (batch_size, query_tokens, h, w)
end2 = time.perf_counter()
print(f"Similarity map computation took: {end2 - start2} s")

# Normalize the similarity map per query token
similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)
# Normalize the similarity map per query token
similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)

# Collect the blended images
start3 = time.perf_counter()
results = []
for idx, img in enumerate(original_images):
original_size = original_sizes[idx] # (width, height)
@@ -248,6 +288,9 @@ def gen_similarity_maps(
# Store the base64-encoded image
result_per_image[token] = blended_img_base64
results.append(result_per_image)
end3 = time.perf_counter()
print(f"Collecting blended images took: {end3 - start3} s")
print(f"Total heatmap generation took: {end3 - start} s")
return results
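
For reference, here is how a flat Vespa patch index maps onto the 2-D ViT patch grid in the loop above — a minimal, self-contained sketch where the sizes and cell values are made up for illustration (the real values come from the ViT config and the query):

import torch

# Illustrative sizes; not the real model configuration.
n_patch_per_dim = 32  # pretend the ViT patch grid is 32 x 32
n_query_tokens = 5

# Cells as they appear in a summaryfeatures similarities tensor (values made up).
cells = [
    {"address": {"patch": "0", "querytoken": "0"}, "value": 1.26},
    {"address": {"patch": "33", "querytoken": "2"}, "value": 0.87},
]

sim_map = torch.zeros(n_query_tokens, n_patch_per_dim, n_patch_per_dim)
for cell in cells:
    patch = int(cell["address"]["patch"])
    token = int(cell["address"]["querytoken"])
    # A flat patch index p lands at grid cell (p // n, p % n), row-major.
    row, col = divmod(patch, n_patch_per_dim)
    sim_map[token, row, col] = cell["value"]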


@@ -285,9 +328,11 @@ async def query_vespa_default(
) -> dict:
async with app.asyncio(connections=1, total_timeout=120) as session:
query_embedding = format_q_embs(q_emb)

start = time.perf_counter()
response: VespaQueryResponse = await session.query(
body={
"yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
"yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
"ranking": "default",
"query": query,
"timeout": timeout,
@@ -298,6 +343,9 @@
},
)
assert response.is_successful(), response.json
stop = time.perf_counter()
print(f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s")
open("response.json", "w").write(json.dumps(response.json))
return format_query_results(query, response)
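
With summaryfeatures added to the select clause, each hit carries its similarity tensor back to the client. A sketch of the response shape the Python code depends on — field values are illustrative, and Vespa may include additional keys alongside "cells":

# One hit from response.json["root"]["children"] (values illustrative):
hit = {
    "fields": {
        "title": "...",
        "full_image": "<base64>",
        "summaryfeatures": {
            "similarities": {
                "cells": [
                    {"address": {"patch": "0", "querytoken": "0"}, "value": 1.26},
                ],
            },
        },
    },
}

# add_sim_maps_to_result() reads it back like this:
vespa_sim_map = hit["fields"].get("summaryfeatures", None)
cells = vespa_sim_map["similarities"]["cells"]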


@@ -447,10 +495,14 @@ def add_sim_maps_to_result(
) -> Dict[str, Any]:
vit_config = load_vit_config(model)
imgs: List[str] = []
vespa_sim_maps: List[Dict[str, Any]] = []
for single_result in result["root"]["children"]:
img = single_result["fields"]["full_image"]
if img:
imgs.append(img)
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
if vespa_sim_map:
vespa_sim_maps.append(vespa_sim_map)
sim_map_imgs = gen_similarity_maps(
model=model,
processor=processor,
@@ -460,6 +512,7 @@
query_embs=q_embs,
token_idx_map=token_to_idx,
images=imgs,
vespa_sim_maps=vespa_sim_maps
)
for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
for token, sim_mapb64 in sim_map_dict.items():
@@ -491,6 +544,7 @@ def add_sim_maps_to_result(
query_embs=q_embs,
token_idx_map=token_to_idx,
images=[image_filepath],
vespa_sim_maps=None,
)
for fig_token in figs_images:
for token, (fig, ax) in fig_token.items():
@@ -92,6 +92,13 @@ schema pdf_page {

}
}
function similarities() {
expression {
sum(
query(qt) * unpack_bits(attribute(embedding)), v
)
}
}
function bm25_score() {
expression {
bm25(title) + bm25(text)
Expand All @@ -108,6 +115,7 @@ schema pdf_page {
max_sim
}
}
summary-features: similarities
}
rank-profile retrieval-and-rerank {
inputs {
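The similarities function in this schema keeps the querytoken and patch dimensions and sums only over the vector dimension v, so every cell is a dot product between one float query-token embedding and one unpacked binary patch embedding. A rough torch equivalent under assumed shapes (128-dim vectors bit-packed into 16 int8 values; this illustrates the semantics, it is not Vespa's implementation):

import torch

# Assumed shapes for illustration:
# qt:     (querytoken, v) float query embeddings, here 5 x 128
# packed: (patch, v / 8) bit-packed patch embeddings, here 1024 x 16 bytes
qt = torch.randn(5, 128)
packed = torch.randint(0, 256, (1024, 16), dtype=torch.uint8)

# unpack_bits-style expansion: each byte becomes 8 {0, 1} floats along v
# (big-endian bit order assumed here).
bits = (packed.unsqueeze(-1).int() >> torch.arange(7, -1, -1)) & 1
unpacked = bits.flatten(-2).float()  # (1024, 128)

# sum(query(qt) * unpack_bits(attribute(embedding)), v)
similarities = torch.einsum("tv,pv->tp", qt, unpacked)  # (querytoken, patch)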
5 changes: 4 additions & 1 deletion visual-retrieval-colpali/main.py
@@ -5,6 +5,7 @@
from fasthtml.common import *
from shad4fast import *
from vespa.application import Vespa
import time

from backend.colpali import (
get_result_from_query,
@@ -103,6 +104,7 @@ async def get(request, query: str, nn: bool = True):
processor = manager.processor
q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)

start = time.perf_counter()
# Fetch real search results from Vespa
result = await get_result_from_query(
app=vespa_app,
@@ -113,13 +115,14 @@
token_to_idx=token_to_idx,
ranking=ranking_value,
)
end = time.perf_counter()
print(f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds")
# Start generating the similarity map in the background
asyncio.create_task(
generate_similarity_map(
model, processor, query, q_embs, token_to_idx, result, query_id
)
)
print("Search results fetched")
search_results = (
result["root"]["children"]
if "root" in result and "children" in result["root"]
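
Note how the handler above returns search results immediately while asyncio.create_task runs heatmap generation in the background. A stripped-down sketch of that fire-and-forget pattern (names hypothetical), including the caveat that the event loop must outlive the task:

import asyncio

async def generate_similarity_map_stub(query_id: str) -> None:
    # Stand-in for the real heatmap generation.
    await asyncio.sleep(0.5)
    print(f"similarity maps ready for {query_id}")

async def handle_search(query_id: str) -> dict:
    results = {"root": {"children": []}}  # pretend Vespa result
    # Fire and forget: the response is not blocked on heatmap generation.
    asyncio.create_task(generate_similarity_map_stub(query_id))
    return results

async def main() -> None:
    await handle_search("q1")
    # Keep the loop alive long enough for the background task to finish.
    await asyncio.sleep(1)

asyncio.run(main())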
