refactor(annotations): move logic from images.py
PaulHax committed Sep 11, 2024
1 parent f2ee488 commit d3bd004
Showing 7 changed files with 319 additions and 192 deletions.
5 changes: 5 additions & 0 deletions src/nrtk_explorer/app/core.py
@@ -98,6 +98,11 @@ def __init__(self, server=None):
self.state.dataset_ids = []
self.state.hovered_id = None

def clear_hovered(**kwargs):
self.state.hovered_id = None

self.state.change("dataset_ids")(clear_hovered)

self._build_ui()

def on_server_ready(self, *args, **kwargs):
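
For reference, trame's state.change returns a decorator, so the direct call above wires clear_hovered to the dataset_ids state variable. A minimal standalone sketch of the same pattern (assumes a default trame server; the variable values are illustrative, not from this commit):

from trame.app import get_server

server = get_server()
state = server.state

def clear_hovered(**kwargs):
    # trame passes the full changed state as keyword arguments
    state.hovered_id = None

# equivalent to decorating clear_hovered with @state.change("dataset_ids")
state.change("dataset_ids")(clear_hovered)

state.dataset_ids = ["1", "2"]  # clear_hovered runs when this change is flushed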
129 changes: 129 additions & 0 deletions src/nrtk_explorer/app/images/annotations.py
@@ -0,0 +1,129 @@
from typing import Any, Callable, Dict, Sequence
from collections import OrderedDict
from functools import lru_cache, partial
from PIL import Image
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.library.coco_utils import partition


ANNOTATION_CACHE_SIZE = 500


class DeleteCallbackRef:
def __init__(self, del_callback, value):
self.del_callback = del_callback
self.value = value

def __del__(self):
self.del_callback()


def get_annotations_from_dataset(
context, add_to_cache_callback, delete_from_cache_callback, dataset_id: str
):
dataset = context.dataset
annotations = [
annotation
for annotation in dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
add_to_cache_callback(dataset_id, annotations)
with_id = partial(delete_from_cache_callback, dataset_id)
return DeleteCallbackRef(with_id, annotations)


class GroundTruthAnnotations:
def __init__(
self,
context, # for dataset
add_to_cache_callback,
delete_from_cache_callback,
):
with_callbacks = partial(
get_annotations_from_dataset,
context,
add_to_cache_callback,
delete_from_cache_callback,
)
self.get_annotations_for_image = lru_cache(maxsize=ANNOTATION_CACHE_SIZE)(with_callbacks)

def get_annotations(self, dataset_ids: Sequence[str]):
return {
dataset_id: self.get_annotations_for_image(dataset_id).value
for dataset_id in dataset_ids
}

def cache_clear(self):
self.get_annotations_for_image.cache_clear()
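
Note the indirection here: functools.lru_cache has no eviction callback, so each cached entry is wrapped in a DeleteCallbackRef. When lru_cache evicts an entry it drops the only reference to the wrapper, CPython's reference counting finalizes it, and __del__ fires the delete callback. A minimal sketch of just that mechanism (assumes CPython's prompt refcount-based finalization; names are illustrative):

from functools import lru_cache

@lru_cache(maxsize=1)
def load(key: str):
    # the wrapper's __del__ runs when lru_cache drops this entry
    return DeleteCallbackRef(lambda: print(f"evicted {key}"), key.upper())

load("a")
load("b")  # cache is full: the entry for "a" is evicted, printing "evicted a"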


class LruCache:
"""Least recently accessed item is removed when the cache is full."""

def __init__(
self,
max_size: int,
on_add_item: Callable[[str, Any], None],
on_clear_item: Callable[[str], None],
):
self.cache: OrderedDict[str, Any] = OrderedDict()
self.max_size = max_size
self.on_add_item = on_add_item
self.on_clear_item = on_clear_item

def add_item(self, key: str, item):
"""Add an item to the cache."""
self.cache[key] = item
self.cache.move_to_end(key)
if len(self.cache) > self.max_size:
oldest = next(iter(self.cache))
self.clear_item(oldest)
self.on_add_item(key, item)

def get_item(self, key: str):
"""Retrieve an item from the cache."""
if key in self.cache:
self.cache.move_to_end(key)
return self.cache[key]
return None

def clear_item(self, key: str):
"""Remove a specific item from the cache."""
if key in self.cache:
self.on_clear_item(key)
del self.cache[key]

def clear(self):
"""Clear the cache."""
for key in self.cache.keys():
self.on_clear_item(key)
self.cache.clear()
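
A quick usage sketch of this cache; the callbacks below only record activity (all names are illustrative):

added, cleared = [], []
cache = LruCache(max_size=2,
                 on_add_item=lambda key, item: added.append(key),
                 on_clear_item=cleared.append)
cache.add_item("a", 1)
cache.add_item("b", 2)
cache.get_item("a")     # "a" is now most recently used
cache.add_item("c", 3)  # over capacity: least recently used "b" is evicted
assert cleared == ["b"] and cache.get_item("b") is None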


class DetectionAnnotations:
def __init__(
self,
add_to_cache_callback,
delete_from_cache_callback,
):
self.cache = LruCache(
ANNOTATION_CACHE_SIZE, add_to_cache_callback, delete_from_cache_callback
)

def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
hits, misses = partition(self.cache.get_item, id_to_image.keys())
cached_predictions = {id: self.cache.get_item(id) for id in hits}

to_detect = {id: id_to_image[id] for id in misses}
predictions = detector.eval(
to_detect,
)
for id, annotations in predictions.items():
self.cache.add_item(id, annotations)

predictions.update(**cached_predictions)
# match input order because of scoring code assumptions
return {id: predictions[id] for id in id_to_image.keys()}

def cache_clear(self):
self.cache.clear()
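
Putting the pieces together: DetectionAnnotations only sends cache misses to the detector and keeps external state in sync through the two callbacks. A hedged end-to-end sketch, with a stub detector and a plain dict standing in for the real ObjectDetector and trame state (everything below is illustrative, not part of this commit):

from PIL import Image

class StubDetector:
    def eval(self, id_to_image):
        # one fake detection per image, shaped like a detector prediction
        box = {"xmin": 0, "ymin": 0, "xmax": 4, "ymax": 4}
        return {id: [{"label": "cat", "box": box}] for id in id_to_image}

fake_state = {}
annotations = DetectionAnnotations(
    add_to_cache_callback=lambda key, item: fake_state.__setitem__(key, item),
    delete_from_cache_callback=lambda key: fake_state.pop(key, None),
)
images = {"img_1": Image.new("RGB", (8, 8))}
first = annotations.get_annotations(StubDetector(), images)   # miss: detector runs
second = annotations.get_annotations(StubDetector(), images)  # hit: served from the LruCache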
15 changes: 15 additions & 0 deletions src/nrtk_explorer/app/images/image_ids.py
@@ -1,3 +1,6 @@
from nrtk_explorer.app.images.image_meta import dataset_id_to_meta


def image_id_to_dataset_id(image_id: str):
return image_id.split("_")[-1]

@@ -16,3 +19,15 @@ def image_id_to_result_id(image_id: str):

def is_transformed(image_id: str):
return image_id.startswith("transformed_img_")


def get_image_state_keys(dataset_id: str):
return {
"ground_truth": image_id_to_result_id(dataset_id),
"original_image_detection": image_id_to_result_id(dataset_id_to_image_id(dataset_id)),
"transformed_image": dataset_id_to_transformed_image_id(dataset_id),
"transformed_image_detection": image_id_to_result_id(
dataset_id_to_transformed_image_id(dataset_id)
),
"meta_id": dataset_id_to_meta(dataset_id),
}
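
get_image_state_keys bundles every trame state key tied to one dataset id; init_state in images.py seeds these keys so the ImageDetection component has live Refs. A short sketch of its use (the exact key strings depend on the collapsed helper bodies above, so the comment below is an assumption):

from nrtk_explorer.app.images.image_ids import get_image_state_keys

keys = get_image_state_keys("123")
# e.g. keys["transformed_image"] would be "transformed_img_123" if the
# "transformed_img_" prefix checked by is_transformed is also used to build
# transformed image ids; each value names one trame state variable
for purpose, state_key in keys.items():
    print(purpose, state_key)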
173 changes: 3 additions & 170 deletions src/nrtk_explorer/app/images/images.py
@@ -1,20 +1,13 @@
from typing import Any, Callable, Dict, Sequence
from collections import OrderedDict
import base64
import io
from functools import lru_cache
from PIL import Image
from trame.app import get_server
from nrtk_explorer.app.images.image_ids import (
image_id_to_result_id,
dataset_id_to_image_id,
dataset_id_to_transformed_image_id,
)
from nrtk_explorer.app.images.image_meta import dataset_id_to_meta, update_image_meta
from nrtk_explorer.app.trame_utils import delete_state, change_checker
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.app.trame_utils import delete_state
from nrtk_explorer.library.transforms import ImageTransform
from nrtk_explorer.library.coco_utils import partition


IMAGE_CACHE_SIZE = 50
@@ -27,40 +20,8 @@ def convert_to_base64(img: Image.Image) -> str:
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


class BufferCache:
"""Least recently accessed item is removed when the cache is full."""

def __init__(self, max_size: int, on_clear_item: Callable[[str], None]):
self.cache: OrderedDict[str, Any] = OrderedDict()
self.max_size = max_size
self.on_clear_item = on_clear_item

def add_item(self, key: str, item):
"""Add an item to the cache."""
self.cache[key] = item
self.cache.move_to_end(key)
if len(self.cache) > self.max_size:
oldest = next(iter(self.cache))
self.clear_item(oldest)

def get_item(self, key: str):
"""Retrieve an item from the cache."""
if key in self.cache:
self.cache.move_to_end(key)
return self.cache[key]
return None

def clear_item(self, key: str):
"""Remove a specific item from the cache."""
if key in self.cache:
self.on_clear_item(key)
del self.cache[key]

def clear(self):
"""Clear the cache."""
for key in self.cache.keys():
self.on_clear_item(key)
self.cache.clear()
server = get_server()
state, context, ctrl = server.state, server.context, server.controller


class RefCountedState:
@@ -72,15 +33,6 @@ def __del__(self):
delete_state(state, self.key)


server = get_server()
state, context, ctrl = server.state, server.context, server.controller


# syncs trame state
def delete_from_state(key: str):
delete_state(state, key)


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_image(dataset_id: str):
image_path = server.controller.get_image_fpath(int(dataset_id))
@@ -106,132 +58,13 @@ def get_transformed_image(transform: ImageTransform, dataset_id: str):
return get_cached_transformed_image(transform, dataset_id).value


def add_annotation_to_state(image_id: str, annotations: Any):
state[image_id_to_result_id(image_id)] = annotations


def delete_annotation_from_state(image_id: str):
delete_state(state, image_id_to_result_id(image_id))


annotation_cache = BufferCache(1000, delete_annotation_from_state)


def prediction_to_annotations(predictions):
annotations = []
for prediction in predictions:
category_id = None
# if no matching category in dataset JSON, category_id will be None
for cat_id, cat in state.annotation_categories.items():
if cat["name"] == prediction["label"]:
category_id = cat_id

bbox = prediction["box"]
annotations.append(
{
"category_id": category_id,
"label": prediction["label"],
"bbox": [
bbox["xmin"],
bbox["ymin"],
bbox["xmax"] - bbox["xmin"],
bbox["ymax"] - bbox["ymin"],
],
}
)
return annotations


def get_annotations(detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
hits, misses = partition(annotation_cache.get_item, id_to_image.keys())

to_detect = {id: id_to_image[id] for id in misses}
predictions = detector.eval(
to_detect,
)
for id, annotations in predictions.items():
annotation_cache.add_item(id, annotations)
add_annotation_to_state(id, prediction_to_annotations(annotations))

predictions.update(**{id: annotation_cache.get_item(id) for id in hits})
# match input order because of scoring code assumptions
return {id: predictions[id] for id in id_to_image.keys()}


def get_ground_truth_annotations(dataset_ids: Sequence[str]):
hits, misses = partition(annotation_cache.get_item, dataset_ids)

annotations = {
dataset_id: [
annotation
for annotation in context.dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
for dataset_id in misses
}

for id, boxes_for_image in annotations.items():
annotation_cache.add_item(id, boxes_for_image)
add_annotation_to_state(id, boxes_for_image)

annotations.update({id: annotation_cache.get_item(id) for id in hits})
return [annotations[dataset_id] for dataset_id in dataset_ids]


def get_image_state_keys(dataset_id: str):
return {
"ground_truth": image_id_to_result_id(dataset_id),
"original_image_detection": image_id_to_result_id(dataset_id_to_image_id(dataset_id)),
"transformed_image": dataset_id_to_transformed_image_id(dataset_id),
"transformed_image_detection": image_id_to_result_id(
dataset_id_to_transformed_image_id(dataset_id)
),
"meta_id": dataset_id_to_meta(dataset_id),
}


@state.change("current_dataset")
def clear_all(**kwargs):
get_image.cache_clear()
get_cached_transformed_image.cache_clear()
annotation_cache.clear()


@change_checker(state, "dataset_ids")
def init_state(old, new):
if old is not None:
# clean old ids that are not in new
old_ids = set(old)
new_ids = set(new)
to_clean = old_ids - new_ids
for id in to_clean:
annotation_cache.clear_item(id) # ground truth
annotation_cache.clear_item(dataset_id_to_image_id(id)) # original image detection
keys = get_image_state_keys(id)
annotation_cache.clear_item(keys["transformed_image"])
delete_state(state, keys["meta_id"])

# create reactive annotation variables so ImageDetection component has live Refs
for id in new:
keys = get_image_state_keys(id)
for key in keys.values():
if not state.has(key):
state[key] = None
state.hovered_id = None


def clear_transformed(**kwargs):
for id in state.dataset_ids:
transformed_image_id = dataset_id_to_transformed_image_id(id)
annotation_cache.clear_item(transformed_image_id)
update_image_meta(
state,
id,
{
"original_detection_to_transformed_detection_score": 0,
"ground_truth_to_transformed_detection_score": 0,
},
)
get_cached_transformed_image.cache_clear()
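
One detail from the removed code worth keeping: prediction_to_annotations converts detector boxes from corner form (xmin, ymin, xmax, ymax) into COCO-style [x, y, width, height]. A worked instance:

box = {"xmin": 10, "ymin": 20, "xmax": 50, "ymax": 80}
coco_bbox = [
    box["xmin"],                # x of the top-left corner
    box["ymin"],                # y of the top-left corner
    box["xmax"] - box["xmin"],  # width  = 50 - 10 = 40
    box["ymax"] - box["ymin"],  # height = 80 - 20 = 60
]
assert coco_bbox == [10, 20, 40, 60]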


(3 more changed files not shown)
