refactor(annotations): move logic from images.py
Showing 4 changed files with 274 additions and 130 deletions.
@@ -0,0 +1,129 @@
from typing import Any, Callable, Dict, Sequence
from collections import OrderedDict
from functools import lru_cache, partial
from PIL import Image
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.library.coco_utils import partition


class DeleteCallbackRef:
    def __init__(self, del_callback, value):
        self.del_callback = del_callback
        self.value = value

    def __del__(self):
        self.del_callback()


ANNOTATION_CACHE_SIZE = 500


def get_annotations_from_dataset(
    context, add_to_cache_callback, delete_from_cache_callback, dataset_id: str
):
    dataset = context.dataset
    annotations = [
        annotation
        for annotation in dataset.anns.values()
        if str(annotation["image_id"]) == dataset_id
    ]
    add_to_cache_callback(dataset_id, annotations)
    with_id = partial(delete_from_cache_callback, dataset_id)
    return DeleteCallbackRef(with_id, annotations)


class GroundTruthAnnotations:
    def __init__(
        self,
        context,  # for dataset
        add_to_cache_callback,
        delete_from_cache_callback,
    ):
        with_callbacks = partial(
            get_annotations_from_dataset,
            context,
            add_to_cache_callback,
            delete_from_cache_callback,
        )
        self.get_annotations_for_image = lru_cache(maxsize=ANNOTATION_CACHE_SIZE)(with_callbacks)

    def get_annotations(self, dataset_ids: Sequence[str]):
        return {
            dataset_id: self.get_annotations_for_image(dataset_id).value
            for dataset_id in dataset_ids
        }

    def cache_clear(self):
        self.get_annotations_for_image.cache_clear()


class LruCache:
    """Least recently accessed item is removed when the cache is full."""

    def __init__(
        self,
        max_size: int,
        on_add_item: Callable[[str, Any], None],
        on_clear_item: Callable[[str], None],
    ):
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_size = max_size
        self.on_add_item = on_add_item
        self.on_clear_item = on_clear_item

    def add_item(self, key: str, item):
        """Add an item to the cache."""
        self.cache[key] = item
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            oldest = next(iter(self.cache))
            self.clear_item(oldest)
        self.on_add_item(key, item)

    def get_item(self, key: str):
        """Retrieve an item from the cache."""
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def clear_item(self, key: str):
        """Remove a specific item from the cache."""
        if key in self.cache:
            self.on_clear_item(key)
            del self.cache[key]

    def clear(self):
        """Clear the cache."""
        for key in self.cache.keys():
            self.on_clear_item(key)
        self.cache.clear()


class DetectionAnnotations:
    def __init__(
        self,
        add_to_cache_callback,
        delete_from_cache_callback,
    ):
        self.cache = LruCache(
            ANNOTATION_CACHE_SIZE, add_to_cache_callback, delete_from_cache_callback
        )

    def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
        hits, misses = partition(self.cache.get_item, id_to_image.keys())
        cached_predictions = {id: self.cache.get_item(id) for id in hits}

        to_detect = {id: id_to_image[id] for id in misses}
        predictions = detector.eval(
            to_detect,
        )
        for id, annotations in predictions.items():
            self.cache.add_item(id, annotations)

        predictions.update(**cached_predictions)
        # match input order because of scoring code assumptions
        return {id: predictions[id] for id in id_to_image.keys()}

    def cache_clear(self):
        self.cache.clear()
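
For context, a minimal sketch of the LruCache eviction behavior introduced above; the on_add/on_clear recorder functions here are hypothetical test helpers, not part of this commit:

added, cleared = {}, []

def on_add(key, item):
    # record every item the cache reports as added
    added[key] = item

def on_clear(key):
    # record every key the cache reports as evicted or cleared
    cleared.append(key)

cache = LruCache(max_size=2, on_add_item=on_add, on_clear_item=on_clear)
cache.add_item("a", 1)
cache.add_item("b", 2)
cache.get_item("a")     # marks "a" as recently used
cache.add_item("c", 3)  # exceeds max_size, evicts "b" (least recently used)
assert cleared == ["b"]
assert cache.get_item("b") is None
assert set(added) == {"a", "b", "c"}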
@@ -0,0 +1,87 @@
from functools import partial
from typing import Any, Union, Callable
from .annotations import GroundTruthAnnotations, DetectionAnnotations
from trame.decorators import TrameApp, change
from nrtk_explorer.app.images.image_ids import (
    image_id_to_result_id,
)
from nrtk_explorer.app.trame_utils import delete_state


def add_annotation_to_state(state: Any, image_id: str, annotations: Any):
    state[image_id_to_result_id(image_id)] = annotations


def delete_annotation_from_state(state: Any, image_id: str):
    delete_state(state, image_id_to_result_id(image_id))


def prediction_to_annotations(state, predictions):
    annotations = []
    for prediction in predictions:
        category_id = None
        # if no matching category in dataset JSON, category_id will be None
        for cat_id, cat in state.annotation_categories.items():
            if cat["name"] == prediction["label"]:
                category_id = cat_id

        bbox = prediction["box"]
        annotations.append(
            {
                "category_id": category_id,
                "label": prediction["label"],
                "bbox": [
                    bbox["xmin"],
                    bbox["ymin"],
                    bbox["xmax"] - bbox["xmin"],
                    bbox["ymax"] - bbox["ymin"],
                ],
            }
        )
    return annotations


def add_prediction_to_state(state: Any, image_id: str, prediction: Any):
    state[image_id_to_result_id(image_id)] = prediction_to_annotations(state, prediction)


AnnotationsFactoryConstructorType = Union[
    Callable[[Callable, Callable], GroundTruthAnnotations],
    Callable[[Callable, Callable], DetectionAnnotations],
]


@TrameApp()
class StatefulAnnotations:
    def __init__(
        self,
        annotations_factory_constructor: AnnotationsFactoryConstructorType,
        server,
        add_to_cache_callback=None,
    ):
        self.server = server
        state = self.server.state
        add_to_cache_callback = add_to_cache_callback or partial(add_annotation_to_state, state)
        delete_from_cache_callback = partial(delete_annotation_from_state, state)
        self.annotations_factory = annotations_factory_constructor(
            add_to_cache_callback, delete_from_cache_callback
        )

    @change("current_dataset")
    def _on_dataset(self, **kwargs):
        self.annotations_factory.cache_clear()


def make_stateful_annotations(server):
    return StatefulAnnotations(
        partial(GroundTruthAnnotations, server.context),
        server,
    )


def make_stateful_predictor(server):
    return StatefulAnnotations(
        DetectionAnnotations,
        server,
        add_to_cache_callback=partial(add_prediction_to_state, server.state),
    )
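
A rough sketch of how these factories might be wired into an application, assuming a trame server is created elsewhere with get_server; the example image ids, detector, and id_to_image mapping are stand-ins, not part of this commit:

from trame.app import get_server

server = get_server()

# Ground-truth annotations are read from server.context.dataset, cached
# via lru_cache, and written into server.state under
# image_id_to_result_id(image_id) by the add_to_cache callback.
ground_truth = make_stateful_annotations(server)

# Detector predictions flow through the LruCache and are converted to
# COCO-style dicts by prediction_to_annotations before landing in state.
predictor = make_stateful_predictor(server)

# Hypothetical usage inside the app, once a dataset and detector exist:
# ground_truth.annotations_factory.get_annotations(["214", "215"])
# predictor.annotations_factory.get_annotations(detector, id_to_image)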