Skip to content

Commit

Permalink
refactor(annotations): move logic from images.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Sep 6, 2024
1 parent f2ee488 commit d54172f
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 130 deletions.
129 changes: 129 additions & 0 deletions src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Any, Callable, Dict, Sequence
from collections import OrderedDict
from functools import lru_cache, partial
from PIL import Image
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.library.coco_utils import partition


class DeleteCallbackRef:
    """Wrap a value and fire a no-arg callback when the wrapper is garbage collected.

    Intended as an ``lru_cache`` value: when the cache evicts the entry the
    last reference is dropped and ``__del__`` triggers external cleanup.
    """

    def __init__(self, del_callback, value):
        # The callback must take no arguments; bind any context beforehand.
        self.value = value
        self.del_callback = del_callback

    def __del__(self):
        # Runs once no references remain (e.g. on cache eviction).
        self.del_callback()


ANNOTATION_CACHE_SIZE = 500


def get_annotations_from_dataset(
    context, add_to_cache_callback, delete_from_cache_callback, dataset_id: str
):
    """Collect ground-truth annotations for one image and mirror them externally.

    Filters ``context.dataset.anns`` down to entries whose ``image_id`` matches
    ``dataset_id``, publishes the result via ``add_to_cache_callback``, and
    returns a DeleteCallbackRef so that dropping the returned object undoes
    the publication through ``delete_from_cache_callback``.
    """
    matching = [
        ann for ann in context.dataset.anns.values() if str(ann["image_id"]) == dataset_id
    ]
    add_to_cache_callback(dataset_id, matching)
    on_delete = partial(delete_from_cache_callback, dataset_id)
    return DeleteCallbackRef(on_delete, matching)


class GroundTruthAnnotations:
    """Serve per-image ground-truth annotations, memoized with an LRU cache.

    Evicting a cache entry drops its DeleteCallbackRef, which in turn removes
    the mirrored external entry via ``delete_from_cache_callback``.
    """

    def __init__(
        self,
        context,  # for dataset
        add_to_cache_callback,
        delete_from_cache_callback,
    ):
        fetch = partial(
            get_annotations_from_dataset,
            context,
            add_to_cache_callback,
            delete_from_cache_callback,
        )
        self.get_annotations_for_image = lru_cache(maxsize=ANNOTATION_CACHE_SIZE)(fetch)

    def get_annotations(self, dataset_ids: Sequence[str]):
        """Return ``{dataset_id: [annotation, ...]}`` for every requested id."""
        annotations_by_id = {}
        for dataset_id in dataset_ids:
            annotations_by_id[dataset_id] = self.get_annotations_for_image(dataset_id).value
        return annotations_by_id

    def cache_clear(self):
        """Drop all memoized entries; their refs then clean up mirrored state."""
        self.get_annotations_for_image.cache_clear()


class LruCache:
    """Bounded mapping that evicts the least-recently-used entry when full.

    ``on_add_item`` fires after every insertion and ``on_clear_item`` fires
    just before any removal, letting callers mirror the cache's contents in
    external state.
    """

    def __init__(
        self,
        max_size: int,
        on_add_item: Callable[[str, Any], None],
        on_clear_item: Callable[[str], None],
    ):
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_size = max_size
        self.on_add_item = on_add_item
        self.on_clear_item = on_clear_item

    def add_item(self, key: str, item):
        """Insert (or refresh) an entry, evicting the stalest one if over capacity."""
        self.cache[key] = item
        self.cache.move_to_end(key)
        while len(self.cache) > self.max_size:
            stalest = next(iter(self.cache))
            self.clear_item(stalest)
        self.on_add_item(key, item)

    def get_item(self, key: str):
        """Return the cached value for ``key`` (marking it recently used) or None."""
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def clear_item(self, key: str):
        """Remove one entry, notifying ``on_clear_item`` first; no-op when absent."""
        if key in self.cache:
            self.on_clear_item(key)
            del self.cache[key]

    def clear(self):
        """Remove every entry, notifying ``on_clear_item`` for each."""
        for key in list(self.cache):
            self.on_clear_item(key)
        self.cache.clear()


class DetectionAnnotations:
    """Run an object detector over images, caching predictions per image id.

    The LruCache mirrors additions/removals into external state via the two
    callbacks supplied at construction time.
    """

    def __init__(
        self,
        add_to_cache_callback,
        delete_from_cache_callback,
    ):
        self.cache = LruCache(
            ANNOTATION_CACHE_SIZE, add_to_cache_callback, delete_from_cache_callback
        )

    def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
        """Return ``{id: predictions}`` for every image, detecting only cache misses."""
        # NOTE(review): partition appears to split ids by truthiness of the cached
        # value, so an image whose cached prediction list is empty would read as a
        # miss and be re-detected each call — confirm partition's semantics.
        hits, misses = partition(self.cache.get_item, id_to_image.keys())
        from_cache = {image_id: self.cache.get_item(image_id) for image_id in hits}

        fresh = detector.eval({image_id: id_to_image[image_id] for image_id in misses})
        for image_id, annotations in fresh.items():
            self.cache.add_item(image_id, annotations)

        fresh.update(from_cache)
        # match input order because of scoring code assumptions
        return {image_id: fresh[image_id] for image_id in id_to_image.keys()}

    def cache_clear(self):
        """Drop all cached predictions; mirrored state is cleaned via callback."""
        self.cache.clear()
112 changes: 0 additions & 112 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,6 @@ def convert_to_base64(img: Image.Image) -> str:
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


class BufferCache:
    """Bounded mapping dropping the least-recently-accessed item when full.

    ``on_clear_item`` is invoked just before any removal so callers can mirror
    removals into external state.
    """

    def __init__(self, max_size: int, on_clear_item: Callable[[str], None]):
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_size = max_size
        self.on_clear_item = on_clear_item

    def add_item(self, key: str, item):
        """Insert or refresh an entry, evicting the stalest one when over capacity."""
        self.cache[key] = item
        self.cache.move_to_end(key)
        if len(self.cache) <= self.max_size:
            return
        stalest = next(iter(self.cache))
        self.clear_item(stalest)

    def get_item(self, key: str):
        """Return the value for ``key`` (marking it recently used) or None if absent."""
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def clear_item(self, key: str):
        """Remove one entry, notifying ``on_clear_item`` first; no-op when absent."""
        if key in self.cache:
            self.on_clear_item(key)
            del self.cache[key]

    def clear(self):
        """Remove every entry, notifying ``on_clear_item`` for each."""
        for key in list(self.cache):
            self.on_clear_item(key)
        self.cache.clear()


class RefCountedState:
def __init__(self, key, value):
self.key = key
Expand All @@ -76,11 +40,6 @@ def __del__(self):
state, context, ctrl = server.state, server.context, server.controller


# syncs trame state
def delete_from_state(key: str):
    """Remove *key* from the module-level trame server state."""
    delete_state(state, key)


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_image(dataset_id: str):
image_path = server.controller.get_image_fpath(int(dataset_id))
Expand Down Expand Up @@ -114,70 +73,6 @@ def delete_annotation_from_state(image_id: str):
delete_state(state, image_id_to_result_id(image_id))


annotation_cache = BufferCache(1000, delete_annotation_from_state)


def prediction_to_annotations(predictions):
    """Convert detector output into COCO-style annotation dicts.

    Each prediction's label is matched against the global
    ``state.annotation_categories`` to recover the dataset category id, and
    its corner-style box is rewritten as COCO's ``[x, y, width, height]``.
    """
    annotations = []
    for prediction in predictions:
        label = prediction["label"]
        category_id = None
        # if no matching category in dataset JSON, category_id will be None
        for cat_id, cat in state.annotation_categories.items():
            if cat["name"] == label:
                category_id = cat_id

        box = prediction["box"]
        x_min, y_min = box["xmin"], box["ymin"]
        annotations.append(
            {
                "category_id": category_id,
                "label": label,
                "bbox": [x_min, y_min, box["xmax"] - x_min, box["ymax"] - y_min],
            }
        )
    return annotations


def get_annotations(detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
    """Return ``{id: predictions}`` for each image, detecting only cache misses.

    Freshly detected results are cached and also pushed into trame state.
    """
    hits, misses = partition(annotation_cache.get_item, id_to_image.keys())

    to_detect = {id: id_to_image[id] for id in misses}
    predictions = detector.eval(
        to_detect,
    )
    for id, annotations in predictions.items():
        annotation_cache.add_item(id, annotations)
        add_annotation_to_state(id, prediction_to_annotations(annotations))

    predictions.update(**{id: annotation_cache.get_item(id) for id in hits})
    # match input order because of scoring code assumptions
    return {id: predictions[id] for id in id_to_image.keys()}


def get_ground_truth_annotations(dataset_ids: Sequence[str]):
    """Return ground-truth annotation lists, one per id, in ``dataset_ids`` order.

    Cache misses are gathered by scanning ``context.dataset.anns`` and are
    mirrored into both the annotation cache and trame state.
    """
    hits, misses = partition(annotation_cache.get_item, dataset_ids)

    annotations = {
        dataset_id: [
            annotation
            for annotation in context.dataset.anns.values()
            if str(annotation["image_id"]) == dataset_id
        ]
        for dataset_id in misses
    }

    for id, boxes_for_image in annotations.items():
        annotation_cache.add_item(id, boxes_for_image)
        add_annotation_to_state(id, boxes_for_image)

    annotations.update({id: annotation_cache.get_item(id) for id in hits})
    return [annotations[dataset_id] for dataset_id in dataset_ids]


def get_image_state_keys(dataset_id: str):
return {
"ground_truth": image_id_to_result_id(dataset_id),
Expand All @@ -194,7 +89,6 @@ def get_image_state_keys(dataset_id: str):
def clear_all(**kwargs):
    """Clear all cached images, transformed images, and annotations."""
    get_image.cache_clear()
    get_cached_transformed_image.cache_clear()
    annotation_cache.clear()


@change_checker(state, "dataset_ids")
Expand All @@ -205,10 +99,6 @@ def init_state(old, new):
new_ids = set(new)
to_clean = old_ids - new_ids
for id in to_clean:
annotation_cache.clear_item(id) # ground truth
annotation_cache.clear_item(dataset_id_to_image_id(id)) # original image detection
keys = get_image_state_keys(id)
annotation_cache.clear_item(keys["transformed_image"])
delete_state(state, keys["meta_id"])

# create reactive annotation variables so ImageDetection component has live Refs
Expand All @@ -222,8 +112,6 @@ def init_state(old, new):

def clear_transformed(**kwargs):
for id in state.dataset_ids:
transformed_image_id = dataset_id_to_transformed_image_id(id)
annotation_cache.clear_item(transformed_image_id)
update_image_meta(
state,
id,
Expand Down
87 changes: 87 additions & 0 deletions src/nrtk_explorer/app/images/stateful_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from functools import partial
from typing import Any, Union, Callable
from .annotations import GroundTruthAnnotations, DetectionAnnotations
from trame.decorators import TrameApp, change
from nrtk_explorer.app.images.image_ids import (
image_id_to_result_id,
)
from nrtk_explorer.app.trame_utils import delete_state


def add_annotation_to_state(state: Any, image_id: str, annotations: Any):
    """Publish *annotations* under the image's result key in trame state."""
    result_key = image_id_to_result_id(image_id)
    state[result_key] = annotations


def delete_annotation_from_state(state: Any, image_id: str):
    """Remove the image's result entry from trame state."""
    result_key = image_id_to_result_id(image_id)
    delete_state(state, result_key)


def prediction_to_annotations(state, predictions):
    """Convert detector predictions to COCO-style annotation dicts.

    Each prediction's label is matched against ``state.annotation_categories``
    to recover the dataset category id, and its corner-style box is rewritten
    as COCO's ``[x, y, width, height]``.
    """
    annotations = []
    for prediction in predictions:
        label = prediction["label"]
        category_id = None
        # if no matching category in dataset JSON, category_id will be None
        for cat_id, cat in state.annotation_categories.items():
            if cat["name"] == label:
                category_id = cat_id

        box = prediction["box"]
        x_min, y_min = box["xmin"], box["ymin"]
        annotations.append(
            {
                "category_id": category_id,
                "label": label,
                "bbox": [x_min, y_min, box["xmax"] - x_min, box["ymax"] - y_min],
            }
        )
    return annotations


def add_prediction_to_state(state: Any, image_id: str, prediction: Any):
    """Convert a raw prediction and publish it under the image's result key."""
    annotations = prediction_to_annotations(state, prediction)
    state[image_id_to_result_id(image_id)] = annotations


# Constructor signature shared by both annotation factories:
# (add_to_cache_callback, delete_from_cache_callback) -> annotations object.
AnnotationsFactoryConstructorType = Union[
    Callable[[Callable, Callable], GroundTruthAnnotations],
    Callable[[Callable, Callable], DetectionAnnotations],
]


@TrameApp()
class StatefulAnnotations:
    """Binds an annotations factory to trame state via add/delete callbacks.

    The factory mirrors its cached annotations into ``server.state`` so UI
    components can react to annotation changes.
    """

    def __init__(
        self,
        annotations_factory_constructor: AnnotationsFactoryConstructorType,
        server,
        add_to_cache_callback=None,
    ):
        self.server = server
        state = self.server.state
        # Default to writing raw annotations straight into trame state.
        add_to_cache_callback = add_to_cache_callback or partial(add_annotation_to_state, state)
        delete_from_cache_callback = partial(delete_annotation_from_state, state)
        self.annotations_factory = annotations_factory_constructor(
            add_to_cache_callback, delete_from_cache_callback
        )

    @change("current_dataset")
    def _on_dataset(self, **kwargs):
        # A new dataset invalidates all cached annotations/predictions.
        self.annotations_factory.cache_clear()


def make_stateful_annotations(server):
    """Build a StatefulAnnotations serving ground-truth boxes from the dataset."""
    factory_constructor = partial(GroundTruthAnnotations, server.context)
    return StatefulAnnotations(factory_constructor, server)


def make_stateful_predictor(server):
    """Build a StatefulAnnotations that runs a detector and mirrors predictions."""
    to_state = partial(add_prediction_to_state, server.state)
    return StatefulAnnotations(
        DetectionAnnotations,
        server,
        add_to_cache_callback=to_state,
    )
Loading

0 comments on commit d54172f

Please sign in to comment.