Skip to content

Commit

Permalink
Merge pull request #1491 from vespa-engine/thomasht86/sim-map-blendin…
Browse files Browse the repository at this point in the history
…g-gpu

(colpalidemo) generator for sim maps
  • Loading branch information
ldalves authored Oct 11, 2024
2 parents b1dfeb8 + 63b3469 commit 579af68
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 32 deletions.
63 changes: 42 additions & 21 deletions visual-retrieval-colpali/backend/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from PIL import Image
import numpy as np
from typing import cast
from typing import cast, Generator
from pathlib import Path
import base64
from io import BytesIO
Expand Down Expand Up @@ -119,7 +119,7 @@ def gen_similarity_maps(
token_idx_map: dict,
images: List[Union[Path, str]],
vespa_sim_maps: List[str],
) -> List[Dict[str, str]]:
) -> Generator[Tuple[int, str, str], None, None]:
"""
Generate similarity maps for the given images and query, and return base64-encoded blended images.
Expand All @@ -132,9 +132,11 @@ def gen_similarity_maps(
query_embs (torch.Tensor): Query embeddings.
token_idx_map (dict): Mapping from tokens to their indices.
images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
vespa_sim_maps (List[str]): List of Vespa similarity maps.
Yields:
Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
Returns:
List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
"""

start = time.perf_counter()
Expand Down Expand Up @@ -178,19 +180,31 @@ def gen_similarity_maps(
# ... and so on.
# Now turn these into a tensor of same shape as previous similarity map
vespa_sim_map_tensor = torch.zeros(
(len(vespa_sim_maps), query_embs.size(dim=1), vit_config.n_patch_per_dim, vit_config.n_patch_per_dim)
(
len(vespa_sim_maps),
query_embs.size(dim=1),
vit_config.n_patch_per_dim,
vit_config.n_patch_per_dim,
)
)
for idx, vespa_sim_map in enumerate(vespa_sim_maps):
for cell in vespa_sim_map["similarities"]["cells"]:
patch = int(cell["address"]["patch"])
if patch >= processor.image_seq_length:
continue
query_token = int(cell["address"]["querytoken"])
value = cell["value"]
vespa_sim_map_tensor[idx, int(query_token), int(patch) // vit_config.n_patch_per_dim, int(patch) % vit_config.n_patch_per_dim] = value
vespa_sim_map_tensor[
idx,
int(query_token),
int(patch) // vit_config.n_patch_per_dim,
int(patch) % vit_config.n_patch_per_dim,
] = value

# Normalize the similarity map per query token
similarity_map_normalized = normalize_similarity_map_per_query_token(vespa_sim_map_tensor)
similarity_map_normalized = normalize_similarity_map_per_query_token(
vespa_sim_map_tensor
)
else:
# Preprocess inputs
print("Computing similarity maps")
Expand Down Expand Up @@ -227,7 +241,9 @@ def gen_similarity_maps(
print(f"Similarity map computation took: {end2 - start2} s")

# Normalize the similarity map per query token
similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)
similarity_map_normalized = normalize_similarity_map_per_query_token(
similarity_map
)

# Collect the blended images
start3 = time.perf_counter()
Expand All @@ -242,8 +258,8 @@ def gen_similarity_maps(
# Get the similarity map for this image and the selected token
sim_map = similarity_map_normalized[idx, token_idx, :, :] # Shape: (h, w)

# Move the similarity map to CPU and convert to NumPy array
sim_map_np = sim_map.cpu().numpy()
# Move the similarity map to CPU, convert to float (as BFloat16 not supported by Numpy) and convert to NumPy array
sim_map_np = sim_map.cpu().float().numpy()

# Resize the similarity map to the original image size
sim_map_img = Image.fromarray(sim_map_np)
Expand Down Expand Up @@ -287,11 +303,7 @@ def gen_similarity_maps(

# Store the base64-encoded image
result_per_image[token] = blended_img_base64
results.append(result_per_image)
end3 = time.perf_counter()
print(f"Collecting blended images took: {end3 - start3} s")
print(f"Total heatmap generation took: {end3 - start} s")
return results
yield idx, token, blended_img_base64


def get_query_embeddings_and_token_map(
Expand Down Expand Up @@ -344,7 +356,9 @@ async def query_vespa_default(
)
assert response.is_successful(), response.json
stop = time.perf_counter()
print(f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s")
print(
f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
)
open("response.json", "w").write(json.dumps(response.json))
return format_query_results(query, response)

Expand Down Expand Up @@ -492,6 +506,8 @@ def add_sim_maps_to_result(
query: str,
q_embs: Any,
token_to_idx: Dict[str, int],
query_id: str,
result_cache,
) -> Dict[str, Any]:
vit_config = load_vit_config(model)
imgs: List[str] = []
Expand All @@ -503,7 +519,7 @@ def add_sim_maps_to_result(
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
if vespa_sim_map:
vespa_sim_maps.append(vespa_sim_map)
sim_map_imgs = gen_similarity_maps(
sim_map_imgs_generator = gen_similarity_maps(
model=model,
processor=processor,
device=model.device,
Expand All @@ -512,11 +528,16 @@ def add_sim_maps_to_result(
query_embs=q_embs,
token_idx_map=token_to_idx,
images=imgs,
vespa_sim_maps=vespa_sim_maps
vespa_sim_maps=vespa_sim_maps,
)
for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
for token, sim_mapb64 in sim_map_dict.items():
single_result["fields"][f"sim_map_{token}"] = sim_mapb64
for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
print(f"Created sim map for image {img_idx} and token {token}")
result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = sim_mapb64
# Update result_cache with the new sim_map
result_cache.set(query_id, result)
# for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
# for token, sim_mapb64 in sim_map_dict.items():
# single_result["fields"][f"sim_map_{token}"] = sim_mapb64
return result


Expand Down
45 changes: 34 additions & 11 deletions visual-retrieval-colpali/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
vespa_app: Vespa = get_vespa_app()

result_cache = LRUCache(max_size=20) # Each result can be ~10MB
task_cache = LRUCache(
max_size=1000
) # Map from query_id to boolean value - False if not all results are ready.
thread_pool = ThreadPoolExecutor()


Expand Down Expand Up @@ -97,7 +100,17 @@ async def get(request, query: str, nn: bool = True):
)
# Generate a unique query_id based on the query and ranking value
query_id = generate_query_id(query + ranking_value)

# See if results are already in cache
if result_cache.get(query_id):
print(f"Results for query_id {query_id} already in cache")
result = result_cache.get(query_id)
search_results = get_results_children(result)
# If task is completed, return the results, but no query_id
if task_cache.get(query_id):
return SearchResult(search_results, None)
# If task is not completed, return the results with query_id
return SearchResult(search_results, query_id)
task_cache.set(query_id, False)
# Fetch model and processor
manager = ModelManager.get_instance()
model = manager.model
Expand All @@ -116,19 +129,26 @@ async def get(request, query: str, nn: bool = True):
ranking=ranking_value,
)
end = time.perf_counter()
print(f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds")
print(
f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
)
# Start generating the similarity map in the background
asyncio.create_task(
generate_similarity_map(
model, processor, query, q_embs, token_to_idx, result, query_id
)
)
search_results = get_results_children(result)
return SearchResult(search_results, query_id)


def get_results_children(result):
search_results = (
result["root"]["children"]
if "root" in result and "children" in result["root"]
else []
)
return SearchResult(search_results, query_id)
return search_results


async def generate_similarity_map(
Expand All @@ -143,22 +163,25 @@ async def generate_similarity_map(
query=query,
q_embs=q_embs,
token_to_idx=token_to_idx,
query_id=query_id,
result_cache=result_cache,
)
sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
result_cache.set(query_id, sim_map_result)
task_cache.set(query_id, True)


@app.get("/updated_search_results")
async def updated_search_results(query_id: str):
data = result_cache.get(query_id)
if data is None:
result = result_cache.get(query_id)
if result is None:
return HTMLResponse(status_code=204)
search_results = (
data["root"]["children"]
if "root" in data and "children" in data["root"]
else []
)
updated_content = SearchResult(results=search_results, query_id=None)
search_results = get_results_children(result)
# Check if task is completed - Stop polling if it is
if task_cache.get(query_id):
updated_content = SearchResult(results=search_results, query_id=None)
else:
updated_content = SearchResult(results=search_results, query_id=query_id)
return updated_content


Expand Down

0 comments on commit 579af68

Please sign in to comment.