Skip to content

Commit

Permalink
refactor(images): use lru_cache for image functions
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Sep 4, 2024
1 parent 71759f5 commit 30b700f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
46 changes: 24 additions & 22 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
import base64
import io
from functools import lru_cache
from PIL import Image
from trame.app import get_server
from nrtk_explorer.app.images.image_ids import (
Expand All @@ -16,6 +17,9 @@
from nrtk_explorer.library.coco_utils import partition


IMAGE_CACHE_SIZE = 50


def convert_to_base64(img: Image.Image) -> str:
"""Convert image to base64 string"""
buf = io.BytesIO()
Expand Down Expand Up @@ -63,43 +67,43 @@ def clear(self):
state, context, ctrl = server.state, server.context, server.controller


class RefCountedState:
def __init__(self, key, value):
self.key = key
self.value = value

def __del__(self):
delete_state(state, self.key)


# syncs trame state
def delete_from_state(key: str):
delete_state(state, key)


image_cache = BufferCache(100, delete_from_state)


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_image(dataset_id: str):
cached_image = image_cache.get_item(dataset_id)
if cached_image is not None:
return cached_image

image_path = server.controller.get_image_fpath(int(dataset_id))
image = Image.open(image_path)

image_cache.add_item(dataset_id, image)
return image


def get_transformed_image(transform: ImageTransform, dataset_id: str):
@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_cached_transformed_image(transform: ImageTransform, dataset_id: str):
key = dataset_id_to_transformed_image_id(dataset_id)
cached_image = image_cache.get_item(key)
if cached_image is not None:
return cached_image

original = get_image(dataset_id)
transformed = transform.execute(original)
if original.size != transformed.size:
# Resize so pixel-wise annotation similarity score works
transformed = transformed.resize(original.size)

image_cache.add_item(key, transformed)

state[key] = convert_to_base64(transformed)
return RefCountedState(key, transformed)


return transformed
def get_transformed_image(transform: ImageTransform, dataset_id: str):
return get_cached_transformed_image(transform, dataset_id).value


def add_annotation_to_state(image_id: str, annotations: Any):
Expand Down Expand Up @@ -188,7 +192,8 @@ def get_image_state_keys(dataset_id: str):

@state.change("current_dataset")
def clear_all(**kwargs):
image_cache.clear()
get_image.cache_clear()
get_cached_transformed_image.cache_clear()
annotation_cache.clear()


Expand All @@ -200,14 +205,11 @@ def init_state(old, new):
new_ids = set(new)
to_clean = old_ids - new_ids
for id in to_clean:
image_cache.clear_item(id) # original image
annotation_cache.clear_item(id) # ground truth
annotation_cache.clear_item(dataset_id_to_image_id(id)) # original image detection
keys = get_image_state_keys(id)
image_cache.clear_item(keys["transformed_image"])
annotation_cache.clear_item(keys["transformed_image"])
for key in keys.values():
delete_state(state, key)
delete_state(state, keys["meta_id"])

# create reactive annotation variables so ImageDetection component has live Refs
for id in new:
Expand All @@ -221,7 +223,6 @@ def init_state(old, new):
def clear_transformed(**kwargs):
for id in state.dataset_ids:
transformed_image_id = dataset_id_to_transformed_image_id(id)
image_cache.clear_item(transformed_image_id)
annotation_cache.clear_item(transformed_image_id)
update_image_meta(
state,
Expand All @@ -231,6 +232,7 @@ def clear_transformed(**kwargs):
"ground_truth_to_transformed_detection_score": 0,
},
)
get_cached_transformed_image.cache_clear()


ctrl.apply_transform.add(clear_transformed)
4 changes: 2 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, server):
self.state.transforms = [k for k in self._transforms.keys()]
self.state.current_transform = self.state.transforms[0]

### Transform enabled control ###
# Transform enabled control ##
self.state.transform_enabled = True

def update_transform_enabled(**kwargs):
Expand All @@ -89,7 +89,7 @@ def transform_became_enabled(old, new):
change_checker(self.state, "transform_enabled", transform_became_enabled)(
self.schedule_transformed_images
)
### end Transform enabled control ###
# end Transform enabled control ##

self.server.controller.add("on_server_ready")(self.on_server_ready)
self.server.controller.apply_transform.add(self.schedule_transformed_images)
Expand Down

0 comments on commit 30b700f

Please sign in to comment.